"""
Prompt reconstruction and token optimization module.

Reconstructs compressed prompts by re-injecting protected entities,
applying semantic filtering, and computing token/cost metrics.

Mathematical Foundations
------------------------
1. Token Counting (Byte-Pair Encoding):
    tokens(T) = |BPE_encode(T, V)|
    Reference: Sennrich et al., "Neural Machine Translation of Rare Words
    with Subword Units", ACL 2016 [1]; tokenizer implementation: tiktoken [3].

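2. Compression Ratio:
    R = 1 - (T_compressed / T_original) ∈ (-∞, 1]
    R < 0 when the reconstructed prompt is longer than the original
    (e.g., after instruction injection).
    Cost savings: C = (T_original - T_compressed) × price_per_token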

3. Placeholder Substitution Order:
    Process placeholders by descending length to avoid partial matches.
    If |ph₁| > |ph₂| and ph₂ ⊂ ph₁, replacing ph₂ first corrupts ph₁.

4. Stopword Filtering:
    T_filtered = {w ∈ T : w ∉ S} where S = stopword set
    Implemented as one regex substitution pass with word boundaries; the cost
    grows with both text length and |S| because `re` tries alternatives in turn.

References
----------
[1] Sennrich, R., Haddow, B., & Birch, A. (2016). Neural Machine Translation 
    of Rare Words with Subword Units. ACL 2016.
    https://arxiv.org/abs/1508.07909

[2] Manning, C. D., & Schütze, H. (1999). Foundations of Statistical
    Natural Language Processing. MIT Press. Chapter 5: Collocations.

[3] OpenAI. (2024). tiktoken: BPE tokenizer for GPT models.
    https://github.com/openai/tiktoken

Performance
-----------
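- _filter_stopwords(): roughly O(n·|S|) in the worst case; Python's `re` is a
  backtracking engine (not Aho-Corasick) and tries stopword alternatives in turn
- _reinject_entities(): O(k·n·m) with k=placeholders, n=text length, m=avg placeholder length
- reconstruct(): O(n·|S| + k·n·m + t_tokenize) where t_tokenize = BPE time
- Token counting: delegated to tiktoken, whose BPE core is implemented in Rust;
  throughput depends on hardware and input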

Author: IntelliDeep Labs Team
License: BSL 1.1
"""

from __future__ import annotations

import logging
import re
from dataclasses import dataclass, field
from typing import Dict, List, Optional

import tiktoken

from nlproxy.core.shield import ShieldResult
from nlproxy.utils.constants import MODEL_PRICING, SEMANTIC_STOPWORDS

logger = logging.getLogger(__name__)


@dataclass
class ReconstructionResult:
    """
    Output container for the reconstruction pipeline.

    Attributes
    ----------
    compressed_text : str
        Final reconstructed prompt ready for LLM consumption.
    original_tokens : int
        Token count of the original prompt (BPE-encoded).
    compressed_tokens : int
        Token count after compression + reconstruction.
    tokens_saved : int
        Absolute reduction: original_tokens - compressed_tokens.
    compression_ratio : float
        Relative savings: tokens_saved / original_tokens ∈ (-∞, 1];
        0.0 when original_tokens is 0.
    cost_saved_usd : float
        Estimated USD savings: tokens_saved × price_per_input_token.
    safety_score : Optional[float]
        Safety validation score from SafetyChecker (0.0 to 1.0).
        reconstruct() does not set it, so it stays None by default.
    audit_log : List[Dict]
        Step-by-step processing log from upstream ShieldResult.
    alerts : List[str]
        Warning messages for edge cases (e.g., negative compression).
    compressed_indices : List[int]
        Original sentence indices retained after compression.
    """
    compressed_text: str
    original_tokens: int
    compressed_tokens: int
    tokens_saved: int
    compression_ratio: float
    cost_saved_usd: float
    safety_score: Optional[float] = None
    audit_log: List[Dict] = field(default_factory=list)
    alerts: List[str] = field(default_factory=list)
    compressed_indices: List[int] = field(default_factory=list)
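

# ---------------------------------------------------------------------------
# Illustrative sketch: the arithmetic behind tokens_saved, compression_ratio
# and cost_saved_usd, shown in isolation. The helper name
# `_demo_compression_metrics` and its parameters are hypothetical; nothing in
# this module calls it, and the numbers in the usage note are made-up inputs.
# ---------------------------------------------------------------------------
def _demo_compression_metrics(
    original_tokens: int,
    compressed_tokens: int,
    price_per_token: float,
) -> tuple[int, float, float]:
    """Return (tokens_saved, compression_ratio, cost_saved_usd)."""
    saved = original_tokens - compressed_tokens
    # Guard against division by zero for an empty original prompt.
    ratio = saved / original_tokens if original_tokens else 0.0
    return saved, ratio, saved * price_per_token


# Usage (assumed inputs): a 100-token prompt reduced to 80 tokens saves 20
# tokens, a ratio of 0.2; growing it to 120 tokens gives a ratio of -0.2,
# which is the case the "Negative compression detected." alert reports.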



class PromptReconstructor:
    """
    Reconstructs compressed prompts with entity re-injection and metrics.

    Pipeline:
    1. Join compressed sentences → coherent text
    2. Optional stopword filtering (reduce noise tokens)
    3. Re-inject protected placeholders with original values
    4. Inject constraint instructions (FORBID/MANDATE)
    5. Normalize whitespace/punctuation
    6. Compute token counts and cost savings via tiktoken

    Key Design Decisions
    --------------------
    - Longest-first placeholder substitution prevents partial matches
    - Tolerant regex matching handles LLM output variations (case, underscores)
    - Stopword filtering applied BEFORE re-injection to avoid breaking code
    - Instruction injection only when stopwords enabled (reduces token overhead)
    """

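    # Pre-compiled stopword regex (shared across instances)
    # Pattern: \b(?:stopword1|stopword2|...)\b[.,]?\s*
    # Alternatives are sorted longest-first: `re` takes the first alternative
    # that matches, so a multi-word stopword must precede its own prefixes.
    _STOPWORD_PATTERN: re.Pattern = re.compile(
        r'\b(?:'
        + '|'.join(re.escape(w) for w in sorted(SEMANTIC_STOPWORDS, key=len, reverse=True))
        + r')\b[.,]?\s*',
        flags=re.IGNORECASE
    )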

    # Line-start cleanup: removes leading punctuation/spaces after stopword removal
    _LINE_START_CLEANUP: re.Pattern = re.compile(r'^[ ,.;:!?]+', flags=re.MULTILINE)

    # Whitespace normalization patterns
    _MULTI_SPACE: re.Pattern = re.compile(r' +')
    _MULTI_NEWLINE: re.Pattern = re.compile(r'\n\s*\n+')


    def __init__(
        self,
        model_name: str = "gpt-4",
        custom_pricing: Optional[Dict[str, float]] = None,
        remove_stopwords: bool = True,
    ) -> None:
        """
        Initialize the PromptReconstructor.

        Parameters
        ----------
        model_name : str, optional
            Target LLM for token counting and pricing (default: "gpt-4").
            If tiktoken does not recognize the model, the tokenizer falls
            back to cl100k_base.
        custom_pricing : Optional[Dict[str, float]], optional
            Override pricing: {"input": price_per_token, "output": price_per_token}.
        remove_stopwords : bool, optional
            Enable semantic stopword filtering (default: True).
            Disable for domains where all words carry meaning (e.g., legal).
        """
        self.model_name = model_name
        self.remove_stopwords = remove_stopwords

        # Resolve pricing: custom > model-specific > default
        self.pricing = (
            custom_pricing 
            or MODEL_PRICING.get(model_name) 
            or MODEL_PRICING.get("default", {"input": 3e-5, "output": 6e-5})
        )

        # Initialize tokenizer with fallback to cl100k_base (GPT-4 encoding)
        try:
            self.tokenizer = tiktoken.encoding_for_model(model_name)
        except KeyError:
            self.tokenizer = tiktoken.get_encoding("cl100k_base")
            logger.warning("Model '%s' not in tiktoken; using cl100k_base", model_name)

        logger.info("PromptReconstructor initialized for '%s'", model_name)


    def _filter_stopwords(self, text: str) -> str:
        """
        Remove semantic stopwords and clean resulting artifacts.

        Two-stage cleanup:
        1. Remove stopwords with optional trailing punctuation/whitespace
        2. Clean line-start artifacts (leading commas, periods, spaces)

        Complexity: one regex substitution pass over the text; roughly
        O(n·|S|) in the worst case, since `re` tries alternatives in turn.
        """
        # Stage 1: Remove stopwords with trailing punctuation/whitespace
        cleaned = self._STOPWORD_PATTERN.sub('', text)

        # Stage 2: Clean line-start artifacts from stopword removal
        lines = cleaned.splitlines()
        cleaned_lines = [self._LINE_START_CLEANUP.sub('', line) for line in lines]
        
        return '\n'.join(cleaned_lines)


    def _get_placeholder_pattern(self, placeholder: str) -> re.Pattern:
        """
        Build tolerant regex pattern for placeholder matching.

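        Allows optional underscores and case-insensitive matching
        to handle LLM output variations (e.g., "__PROT_abc" vs "_PROT_abc").

        Pattern: re.escape(placeholder).replace('_', '[_]?')
        (re.escape leaves '_' unescaped on Python 3.7+, so the plain
        character is replaced). Every underscore becomes optional, so the
        pattern also matches the placeholder with no underscores at all.
        Compiled with re.IGNORECASE for case-insensitive matching.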
        """
        escaped = re.escape(placeholder)
        # Make underscores optional: _ → [_]?
        tolerant_pattern = escaped.replace('_', '[_]?')
        return re.compile(tolerant_pattern, flags=re.IGNORECASE)


    def _reinject_entities(self, text: str, placeholder_map: Dict[str, str]) -> str:
        """
        Re-inject protected entity values by replacing placeholders.

        Uses longest-first substitution to avoid partial matches:
        - Sort placeholders by length descending
        - Replace longer placeholders first to prevent corruption

        Example:
            placeholder_map = {"__PROT_ab": "A", "__PROT_abc": "B"}
            text = "__PROT_abc"
            If "__PROT_ab" replaced first: "__PROT_abc" → "Ac" (wrong!)
            If "__PROT_abc" replaced first: "__PROT_abc" → "B" (correct ✓)

        Complexity: O(k·n·m) where k=placeholders, n=text length, m=avg placeholder length.
        """
        # Sort by length descending to avoid partial matches
        sorted_placeholders = sorted(placeholder_map.keys(), key=len, reverse=True)

        result = text
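        for placeholder in sorted_placeholders:
            value = placeholder_map[placeholder]
            pattern = self._get_placeholder_pattern(placeholder)
            # Pass a callable so backslashes and group references in the value
            # (e.g., Windows paths, regexes) are inserted literally instead of
            # being interpreted as a replacement template.
            result = pattern.sub(lambda _match: value, result)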

        return result


    def inject_instructions(self, compressed_text: str, shield_result: ShieldResult) -> str:
        """
        Inject critical instructions based on extracted restrictions.

        Adds a prefixed instruction block guiding the LLM to respect
        FORBID/MANDATE constraints and preserve protected placeholders.

        Format:
            [CRITICAL INSTRUCTIONS]
            Do not use: Python, Java.
            Must use: Rust.
            Placeholders __PROT_xxx contain protected data; do not modify.
            [/CRITICAL INSTRUCTIONS]

            Original compressed text...
        """
        instructions: List[str] = []

        if shield_result.restrictions:
            forbidden = [r.entity for r in shield_result.restrictions if r.type == "FORBID"]
            mandated = [r.entity for r in shield_result.restrictions if r.type == "MANDATE"]
            
            if forbidden:
                instructions.append(f"Do not use: {', '.join(forbidden)}.")
            if mandated:
                instructions.append(f"Must use: {', '.join(mandated)}.")

        if shield_result.placeholder_map:
            instructions.append(
                "Placeholders __PROT_xxx contain protected data; do not modify."
            )

        if instructions:
            instruction_block = (
                "[CRITICAL INSTRUCTIONS]\n"
                + "\n".join(instructions)
                + "\n[/CRITICAL INSTRUCTIONS]\n\n"
            )
            return instruction_block + compressed_text
        
        return compressed_text


    def reconstruct(
        self,
        original_prompt: str,
        compressed_sentences: List[str],
        shield_result: ShieldResult,
        apply_stopwords: Optional[bool] = None,
        compressed_indices: Optional[List[int]] = None,
        privacy_mode: bool = False,
    ) -> ReconstructionResult:
        """
        Execute the complete reconstruction pipeline.

        Parameters
        ----------
        original_prompt : str
            Original uncompressed prompt (for token comparison).
        compressed_sentences : List[str]
            Sentences after semantic compression.
        shield_result : ShieldResult
            Result from PromptShield with placeholders and restrictions.
        apply_stopwords : Optional[bool], optional
            Override instance default for stopword filtering.
        compressed_indices : Optional[List[int]], optional
            Original indices of compressed sentences for traceability.
        privacy_mode : bool, optional
            If True, suppresses entity re-injection to keep sensitive data masked.

        Returns
        -------
        ReconstructionResult
            Container with reconstructed text, metrics, and audit info.

        Pipeline Stages
        ---------------
        1. Join compressed sentences with newlines
        2. Optional stopword filtering (reduces token count)
        3. Re-inject protected entities via placeholder_map
        4. Inject constraint instructions (if stopwords enabled)
        5. Whitespace and punctuation normalization
        6. Token metrics and cost calculations

        Complexity: O(n·|S| + k·n·m + t_tokenize) in the worst case,
        where t_tokenize = BPE encoding time.
        """
        apply_stopwords = apply_stopwords if apply_stopwords is not None else self.remove_stopwords

        # Stage 1: Join compressed sentences
        compressed_text = "\n".join(compressed_sentences)

        # Stage 2: Stopword filtering (BEFORE re-injection to avoid breaking code)
        if apply_stopwords:
            compressed_text = self._filter_stopwords(compressed_text)

        # Stage 3: Re-inject protected entities (unless privacy mode is enabled)
        if not privacy_mode:
            final_text = self._reinject_entities(compressed_text, shield_result.placeholder_map)
        else:
            final_text = compressed_text

        # Stage 4: Inject instructions (only if stopwords enabled)
        if apply_stopwords:
            final_text = self.inject_instructions(final_text, shield_result)

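        # Stage 5: Normalize whitespace and punctuation.
        # NOTE: this runs on the full text, so runs of spaces and leading
        # punctuation inside re-injected values are normalized as well.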
        final_text = self._MULTI_SPACE.sub(' ', final_text)
        final_text = self._LINE_START_CLEANUP.sub('', final_text)
        final_text = self._MULTI_NEWLINE.sub('\n', final_text).strip()

        # Stage 6: Compute metrics
        # Use an explicit None check so an explicitly passed empty list is kept.
        final_indices = (
            sorted(compressed_indices)
            if compressed_indices is not None
            else list(range(len(compressed_sentences)))
        )

        orig_tokens = len(self.tokenizer.encode(original_prompt))
        comp_tokens = len(self.tokenizer.encode(final_text))
        saved = orig_tokens - comp_tokens
        ratio = saved / orig_tokens if orig_tokens else 0.0
        cost = saved * self.pricing["input"]

        alerts = []
        if saved < 0:
            alerts.append("Negative compression detected.")
        if ratio > 0.9:
            alerts.append("Aggressive compression may affect quality.")

        return ReconstructionResult(
            compressed_text=final_text,
            original_tokens=orig_tokens,
            compressed_tokens=comp_tokens,
            tokens_saved=saved,
            compression_ratio=ratio,
            cost_saved_usd=cost,
            audit_log=shield_result.audit_log.copy(),
            alerts=alerts,
            compressed_indices=final_indices
        )
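

# ---------------------------------------------------------------------------
# Illustrative sketch: a standalone version of the longest-first substitution
# described in _reinject_entities(). The helper name `_demo_longest_first`
# and the sample placeholders are hypothetical; nothing in this module calls
# it. It uses exact str.replace() rather than the tolerant regex matching
# above, so it only demonstrates why the ordering matters.
# ---------------------------------------------------------------------------
def _demo_longest_first(text: str, placeholder_map: dict[str, str]) -> str:
    """Replace placeholders longest-first using plain string replacement."""
    result = text
    # Longer keys first, so "__PROT_abc" is consumed before "__PROT_ab".
    for placeholder in sorted(placeholder_map, key=len, reverse=True):
        result = result.replace(placeholder, placeholder_map[placeholder])
    return result


# Usage (assumed inputs):
#   _demo_longest_first("x __PROT_abc y __PROT_ab",
#                       {"__PROT_ab": "A", "__PROT_abc": "B"})
#   returns "x B y A"; replacing "__PROT_ab" first would give "x Ac y A".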