File size: 12,707 Bytes
4afcb3a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
"""
adversarial_detector.py
========================
Detects adversarial / anomalous inputs that may be crafted to manipulate
AI models or evade safety filters.

Detection layers (all zero-dependency except the optional embedding layer):
  1. Token-length analysis     — unusually long or repetitive prompts
  2. Character distribution    — abnormal char class ratios (unicode tricks, homoglyphs)
  3. Repetition detection      — token/n-gram flooding
  4. Encoding obfuscation      — base64 blobs, hex strings, ROT-13 traces
  5. Statistical anomaly       — entropy, symbol density, whitespace abuse
  6. Embedding outlier         — cosine distance from "normal" centroid (optional)
"""

from __future__ import annotations

import re
import math
import time
import unicodedata
import logging
from collections import Counter
from dataclasses import dataclass, field
from typing import List, Optional

logger = logging.getLogger("ai_firewall.adversarial_detector")


# ---------------------------------------------------------------------------
# Config defaults (tunable without subclassing)
# ---------------------------------------------------------------------------

# Every key may be overridden per-instance via AdversarialDetector(config={...}).
DEFAULT_CONFIG = {
    "max_token_length": 4096,      # chars (rough token proxy)
    "max_word_count": 800,
    "max_line_count": 200,
    "repetition_threshold": 0.45,  # fraction of repeated trigrams → adversarial
    "entropy_min": 2.5,            # too-low entropy = repetitive junk
    "entropy_max": 5.8,            # too-high entropy = encoded/random content
    "symbol_density_max": 0.35,    # fraction of non-alphanumeric chars
    "unicode_escape_threshold": 5, # count of \uXXXX / \xXX sequences
    "base64_min_length": 40,       # minimum length of candidate b64 blocks
    "homoglyph_threshold": 3,      # count of confusable lookalike chars
}

# Homoglyph mapping (Cyrillic / Greek / other confusable lookalikes for latin)
# Keys are the non-latin confusables; values are the latin letters they imitate.
_HOMOGLYPH_MAP = {
    "а": "a", "е": "e", "і": "i", "о": "o", "р": "p", "с": "c",
    "х": "x", "у": "y", "ѕ": "s", "ј": "j", "ԁ": "d", "ɡ": "g",
    "ʜ": "h", "ᴛ": "t", "ᴡ": "w", "ᴍ": "m", "ᴋ": "k",
    "α": "a", "ε": "e", "ο": "o", "ρ": "p", "ν": "v", "κ": "k",
}

# NOTE: the {10,} repeat hardcodes a 40-char minimum (10 groups of 4), i.e. the
# DEFAULT_CONFIG["base64_min_length"] value — overriding that config key does
# not affect this module-level pattern.
_BASE64_RE = re.compile(r"(?:[A-Za-z0-9+/]{4}){10,}(?:[A-Za-z0-9+/]{2}==|[A-Za-z0-9+/]{3}=)?")
# 16+ consecutive hex digits, with or without a 0x prefix.
_HEX_RE    = re.compile(r"(?:0x)?[0-9a-fA-F]{16,}")
# \uXXXX, \xXX and %XX (URL-style) escape sequences.
_UNICODE_ESC_RE = re.compile(r"(\\u[0-9a-fA-F]{4}|\\x[0-9a-fA-F]{2}|%[0-9a-fA-F]{2})")


@dataclass
class AdversarialResult:
    """Outcome of a single adversarial-input scan."""

    is_adversarial: bool               # final verdict against the detector threshold
    risk_score: float                  # aggregate risk in [0.0, 1.0]
    flags: List[str] = field(default_factory=list)   # names of the triggered checks
    details: dict = field(default_factory=dict)      # per-check diagnostic values
    latency_ms: float = 0.0            # wall-clock time spent scanning

    def to_dict(self) -> dict:
        """Serialize to a plain dict (scores rounded for stable logging)."""
        payload = dict(
            is_adversarial=self.is_adversarial,
            risk_score=round(self.risk_score, 4),
            flags=self.flags,
            details=self.details,
            latency_ms=round(self.latency_ms, 2),
        )
        return payload


class AdversarialDetector:
    """
    Stateless adversarial input detector.

    A prompt is considered adversarial if its aggregate risk score
    exceeds `threshold` (default 0.55).

    Parameters
    ----------
    threshold : float
        Risk score above which input is flagged.
    config : dict, optional
        Override any key from DEFAULT_CONFIG.
    use_embeddings : bool
        Enable embedding-outlier detection (requires sentence-transformers).
    embedding_model : str
        Model name for the embedding layer.
    """

    def __init__(
        self,
        threshold: float = 0.55,
        config: Optional[dict] = None,
        use_embeddings: bool = False,
        embedding_model: str = "all-MiniLM-L6-v2",
    ) -> None:
        self.threshold = threshold
        self.cfg = {**DEFAULT_CONFIG, **(config or {})}
        self.use_embeddings = use_embeddings
        self._embedder = None
        self._normal_centroid = None  # set via `fit_normal_distribution`

        # FIX: the module-level _BASE64_RE hardcodes its 40-char minimum, so the
        # "base64_min_length" config key was silently ignored.  Compile a
        # per-instance pattern that honors the configured minimum.  The default
        # (40 chars == 10 groups of 4) reproduces the old behavior exactly.
        b64_groups = max(1, int(self.cfg["base64_min_length"]) // 4)
        self._base64_re = re.compile(
            r"(?:[A-Za-z0-9+/]{4}){%d,}(?:[A-Za-z0-9+/]{2}==|[A-Za-z0-9+/]{3}=)?" % b64_groups
        )

        if use_embeddings:
            self._load_embedder(embedding_model)

    # ------------------------------------------------------------------
    # Embedding layer
    # ------------------------------------------------------------------

    def _load_embedder(self, model_name: str) -> None:
        """Load the optional sentence-transformers model.

        Missing optional dependencies disable the embedding layer instead of
        raising, so the rule-based layers keep working.
        """
        try:
            from sentence_transformers import SentenceTransformer
            import numpy as np  # noqa: F401 — probe that numpy is importable up front
            self._embedder = SentenceTransformer(model_name)
            logger.info("Adversarial embedding layer loaded: %s", model_name)
        except ImportError:
            logger.warning("sentence-transformers not installed — embedding outlier layer disabled.")
            self.use_embeddings = False

    def fit_normal_distribution(self, normal_prompts: List[str]) -> None:
        """
        Compute the centroid of embedding vectors for a set of known-good
        prompts.  Call this once at startup with representative benign prompts.
        """
        if not self.use_embeddings or self._embedder is None:
            return
        if not normal_prompts:
            # FIX: encoding an empty list would produce a NaN centroid after
            # the mean/normalise step, silently poisoning every later check.
            logger.warning("fit_normal_distribution called with no prompts — centroid not set.")
            return
        import numpy as np
        embeddings = self._embedder.encode(normal_prompts, convert_to_numpy=True, normalize_embeddings=True)
        self._normal_centroid = embeddings.mean(axis=0)
        norm = np.linalg.norm(self._normal_centroid)
        if norm == 0.0:
            # FIX: guard the division — a degenerate all-zero centroid would
            # otherwise become all-NaN and break _check_embedding_outlier.
            logger.warning("Normal centroid has zero norm — embedding outlier layer inactive.")
            self._normal_centroid = None
            return
        self._normal_centroid /= norm
        logger.info("Normal centroid computed from %d prompts.", len(normal_prompts))

    # ------------------------------------------------------------------
    # Individual checks — each returns (score in [0,1], "|"-joined flags, details)
    # ------------------------------------------------------------------

    def _check_length(self, text: str) -> tuple[float, str, dict]:
        """Flag prompts that are unusually long by chars, words or lines."""
        char_len = len(text)
        word_count = len(text.split())
        # NOTE: counts newline separators, i.e. (lines - 1) for non-empty text;
        # at the default threshold of 200 the off-by-one is immaterial.
        line_count = text.count("\n")
        score = 0.0
        flags = []

        if char_len > self.cfg["max_token_length"]:
            score += 0.4
            flags.append("excessive_length")
        if word_count > self.cfg["max_word_count"]:
            score += 0.25
            flags.append("excessive_word_count")
        if line_count > self.cfg["max_line_count"]:
            score += 0.2
            flags.append("excessive_line_count")

        details = {"char_len": char_len, "word_count": word_count, "line_count": line_count}
        return min(score, 1.0), "|".join(flags), details

    def _check_repetition(self, text: str) -> tuple[float, str, dict]:
        """Flag token flooding via the fraction of duplicated word trigrams."""
        words = text.lower().split()
        if len(words) < 6:
            # Too short for trigram statistics to be meaningful.
            return 0.0, "", {}
        trigrams = [tuple(words[i:i+3]) for i in range(len(words) - 2)]
        if not trigrams:
            return 0.0, "", {}
        total = len(trigrams)
        unique = len(set(trigrams))
        repetition_ratio = 1.0 - (unique / total)
        score = 0.0
        flag = ""
        if repetition_ratio >= self.cfg["repetition_threshold"]:
            score = min(repetition_ratio, 1.0)
            flag = "high_token_repetition"
        return score, flag, {"repetition_ratio": round(repetition_ratio, 3)}

    def _check_entropy(self, text: str) -> tuple[float, str, dict]:
        """Flag character-level Shannon entropy outside the normal band.

        Too low → repetitive junk; too high → encoded / random content.
        """
        if not text:
            return 0.0, "", {}
        freq = Counter(text)
        total = len(text)
        entropy = -sum((c / total) * math.log2(c / total) for c in freq.values())
        score = 0.0
        flag = ""
        if entropy < self.cfg["entropy_min"]:
            score = 0.5
            flag = "low_entropy_repetitive"
        elif entropy > self.cfg["entropy_max"]:
            score = 0.6
            flag = "high_entropy_possibly_encoded"
        return score, flag, {"entropy": round(entropy, 3)}

    def _check_symbol_density(self, text: str) -> tuple[float, str, dict]:
        """Flag a high fraction of non-alphanumeric, non-whitespace chars."""
        if not text:
            return 0.0, "", {}
        non_alnum = sum(1 for c in text if not c.isalnum() and not c.isspace())
        density = non_alnum / len(text)
        score = 0.0
        flag = ""
        if density > self.cfg["symbol_density_max"]:
            score = min(density, 1.0)
            flag = "high_symbol_density"
        return score, flag, {"symbol_density": round(density, 3)}

    def _check_encoding_obfuscation(self, text: str) -> tuple[float, str, dict]:
        """Detect escape sequences, base64-like blobs and long hex strings."""
        score = 0.0
        flags = []
        details = {}

        # Unicode / URL escape sequences (\uXXXX, \xXX, %XX)
        esc_matches = _UNICODE_ESC_RE.findall(text)
        if len(esc_matches) >= self.cfg["unicode_escape_threshold"]:
            score += 0.5
            flags.append("unicode_escape_sequences")
            details["unicode_escapes"] = len(esc_matches)

        # Base64-like blobs (per-instance pattern — honors cfg["base64_min_length"])
        b64_matches = self._base64_re.findall(text)
        if b64_matches:
            score += 0.4
            flags.append("base64_like_content")
            details["base64_blocks"] = len(b64_matches)

        # Long hex strings
        hex_matches = _HEX_RE.findall(text)
        if hex_matches:
            score += 0.3
            flags.append("hex_encoded_content")
            details["hex_blocks"] = len(hex_matches)

        return min(score, 1.0), "|".join(flags), details

    def _check_homoglyphs(self, text: str) -> tuple[float, str, dict]:
        """Count confusable lookalike chars; saturates at 20 occurrences."""
        count = sum(1 for ch in text if ch in _HOMOGLYPH_MAP)
        score = 0.0
        flag = ""
        if count >= self.cfg["homoglyph_threshold"]:
            score = min(count / 20, 1.0)
            flag = "homoglyph_substitution"
        return score, flag, {"homoglyph_count": count}

    def _check_unicode_normalization(self, text: str) -> tuple[float, str, dict]:
        """Detect invisible / zero-width / direction-override characters."""
        bad_categories = {"Cf", "Cs", "Co"}  # format, surrogate, private-use
        bad_chars = [c for c in text if unicodedata.category(c) in bad_categories]
        score = 0.0
        flag = ""
        # A couple of format chars (e.g. soft hyphens) are tolerated as benign.
        if len(bad_chars) > 2:
            score = min(len(bad_chars) / 10, 1.0)
            flag = "invisible_unicode_chars"
        return score, flag, {"invisible_char_count": len(bad_chars)}

    def _check_embedding_outlier(self, text: str) -> tuple[float, str, dict]:
        """Score cosine distance from the fitted benign-prompt centroid.

        Returns zero (layer inactive) until `fit_normal_distribution` has run.
        """
        if not self.use_embeddings or self._embedder is None or self._normal_centroid is None:
            return 0.0, "", {}
        try:
            import numpy as np
            emb = self._embedder.encode(text, convert_to_numpy=True, normalize_embeddings=True)
            similarity = float(emb @ self._normal_centroid)
            distance = 1.0 - similarity  # 0 = identical to normal, 1 = orthogonal
            score = max(0.0, (distance - 0.3) / 0.7)  # linear rescale [0.3, 1.0] → [0, 1]
            flag = "embedding_outlier" if score > 0.3 else ""
            return score, flag, {"centroid_distance": round(distance, 4)}
        except Exception as exc:
            # Best-effort layer: never let an embedding failure block detection.
            logger.debug("Embedding outlier check failed: %s", exc)
            return 0.0, "", {}

    # ------------------------------------------------------------------
    # Aggregation
    # ------------------------------------------------------------------

    def detect(self, text: str) -> AdversarialResult:
        """
        Run all detection layers and return an AdversarialResult.

        The per-check scores are combined as a weighted average (weights
        normalised below), then compared against `self.threshold`.

        Parameters
        ----------
        text : str
            Raw user prompt.
        """
        t0 = time.perf_counter()

        checks = [
            self._check_length(text),
            self._check_repetition(text),
            self._check_entropy(text),
            self._check_symbol_density(text),
            self._check_encoding_obfuscation(text),
            self._check_homoglyphs(text),
            self._check_unicode_normalization(text),
            self._check_embedding_outlier(text),
        ]

        aggregate_score = 0.0
        all_flags: List[str] = []
        all_details: dict = {}

        # One weight per check above, in the same order; sum > 1 is fine
        # because the aggregate is divided by weight_sum below.
        weights = [0.15, 0.20, 0.15, 0.10, 0.20, 0.10, 0.10, 0.20]

        weight_sum = sum(weights)
        for (score, flag, details), weight in zip(checks, weights):
            aggregate_score += score * weight
            if flag:
                all_flags.extend(flag.split("|"))
            all_details.update(details)

        risk_score = min(aggregate_score / weight_sum, 1.0)
        is_adversarial = risk_score >= self.threshold

        latency = (time.perf_counter() - t0) * 1000

        result = AdversarialResult(
            is_adversarial=is_adversarial,
            risk_score=risk_score,
            flags=list(filter(None, all_flags)),
            details=all_details,
            latency_ms=latency,
        )

        if is_adversarial:
            logger.warning("Adversarial input detected | score=%.3f flags=%s", risk_score, all_flags)

        return result

    def is_safe(self, text: str) -> bool:
        """Convenience wrapper: True iff `text` is NOT flagged adversarial."""
        return not self.detect(text).is_adversarial