""" adversarial_detector.py ======================== Detects adversarial / anomalous inputs that may be crafted to manipulate AI models or evade safety filters. Detection layers (all zero-dependency except the optional embedding layer): 1. Token-length analysis — unusually long or repetitive prompts 2. Character distribution — abnormal char class ratios (unicode tricks, homoglyphs) 3. Repetition detection — token/n-gram flooding 4. Encoding obfuscation — base64 blobs, hex strings, ROT-13 traces 5. Statistical anomaly — entropy, symbol density, whitespace abuse 6. Embedding outlier — cosine distance from "normal" centroid (optional) """ from __future__ import annotations import re import math import time import unicodedata import logging from collections import Counter from dataclasses import dataclass, field from typing import List, Optional logger = logging.getLogger("ai_firewall.adversarial_detector") # --------------------------------------------------------------------------- # Config defaults (tunable without subclassing) # --------------------------------------------------------------------------- DEFAULT_CONFIG = { "max_token_length": 4096, # chars (rough token proxy) "max_word_count": 800, "max_line_count": 200, "repetition_threshold": 0.45, # fraction of repeated trigrams → adversarial "entropy_min": 2.5, # too-low entropy = repetitive junk "entropy_max": 5.8, # too-high entropy = encoded/random content "symbol_density_max": 0.35, # fraction of non-alphanumeric chars "unicode_escape_threshold": 5, # count of \uXXXX / \xXX sequences "base64_min_length": 40, # minimum length of candidate b64 blocks "homoglyph_threshold": 3, # count of confusable lookalike chars } # Homoglyph mapping (Cyrillic / Greek / other confusable lookalikes for latin) _HOMOGLYPH_MAP = { "а": "a", "е": "e", "і": "i", "о": "o", "р": "p", "с": "c", "х": "x", "у": "y", "ѕ": "s", "ј": "j", "ԁ": "d", "ɡ": "g", "ʜ": "h", "ᴛ": "t", "ᴡ": "w", "ᴍ": "m", "ᴋ": "k", "α": "a", "ε": "e", "ο": "o", "ρ": "p", "ν": "v", "κ": "k", } _BASE64_RE = re.compile(r"(?:[A-Za-z0-9+/]{4}){10,}(?:[A-Za-z0-9+/]{2}==|[A-Za-z0-9+/]{3}=)?") _HEX_RE = re.compile(r"(?:0x)?[0-9a-fA-F]{16,}") _UNICODE_ESC_RE = re.compile(r"(\\u[0-9a-fA-F]{4}|\\x[0-9a-fA-F]{2}|%[0-9a-fA-F]{2})") @dataclass class AdversarialResult: is_adversarial: bool risk_score: float # 0.0 – 1.0 flags: List[str] = field(default_factory=list) details: dict = field(default_factory=dict) latency_ms: float = 0.0 def to_dict(self) -> dict: return { "is_adversarial": self.is_adversarial, "risk_score": round(self.risk_score, 4), "flags": self.flags, "details": self.details, "latency_ms": round(self.latency_ms, 2), } class AdversarialDetector: """ Stateless adversarial input detector. A prompt is considered adversarial if its aggregate risk score exceeds `threshold` (default 0.55). Parameters ---------- threshold : float Risk score above which input is flagged. config : dict, optional Override any key from DEFAULT_CONFIG. use_embeddings : bool Enable embedding-outlier detection (requires sentence-transformers). embedding_model : str Model name for the embedding layer. """ def __init__( self, threshold: float = 0.55, config: Optional[dict] = None, use_embeddings: bool = False, embedding_model: str = "all-MiniLM-L6-v2", ) -> None: self.threshold = threshold self.cfg = {**DEFAULT_CONFIG, **(config or {})} self.use_embeddings = use_embeddings self._embedder = None self._normal_centroid = None # set via `fit_normal_distribution` if use_embeddings: self._load_embedder(embedding_model) # ------------------------------------------------------------------ # Embedding layer # ------------------------------------------------------------------ def _load_embedder(self, model_name: str) -> None: try: from sentence_transformers import SentenceTransformer import numpy as np self._embedder = SentenceTransformer(model_name) logger.info("Adversarial embedding layer loaded: %s", model_name) except ImportError: logger.warning("sentence-transformers not installed — embedding outlier layer disabled.") self.use_embeddings = False def fit_normal_distribution(self, normal_prompts: List[str]) -> None: """ Compute the centroid of embedding vectors for a set of known-good prompts. Call this once at startup with representative benign prompts. """ if not self.use_embeddings or self._embedder is None: return import numpy as np embeddings = self._embedder.encode(normal_prompts, convert_to_numpy=True, normalize_embeddings=True) self._normal_centroid = embeddings.mean(axis=0) self._normal_centroid /= np.linalg.norm(self._normal_centroid) logger.info("Normal centroid computed from %d prompts.", len(normal_prompts)) # ------------------------------------------------------------------ # Individual checks # ------------------------------------------------------------------ def _check_length(self, text: str) -> tuple[float, str, dict]: char_len = len(text) word_count = len(text.split()) line_count = text.count("\n") score = 0.0 details, flags = {}, [] if char_len > self.cfg["max_token_length"]: score += 0.4 flags.append("excessive_length") if word_count > self.cfg["max_word_count"]: score += 0.25 flags.append("excessive_word_count") if line_count > self.cfg["max_line_count"]: score += 0.2 flags.append("excessive_line_count") details = {"char_len": char_len, "word_count": word_count, "line_count": line_count} return min(score, 1.0), "|".join(flags), details def _check_repetition(self, text: str) -> tuple[float, str, dict]: words = text.lower().split() if len(words) < 6: return 0.0, "", {} trigrams = [tuple(words[i:i+3]) for i in range(len(words) - 2)] if not trigrams: return 0.0, "", {} total = len(trigrams) unique = len(set(trigrams)) repetition_ratio = 1.0 - (unique / total) score = 0.0 flag = "" if repetition_ratio >= self.cfg["repetition_threshold"]: score = min(repetition_ratio, 1.0) flag = "high_token_repetition" return score, flag, {"repetition_ratio": round(repetition_ratio, 3)} def _check_entropy(self, text: str) -> tuple[float, str, dict]: if not text: return 0.0, "", {} freq = Counter(text) total = len(text) entropy = -sum((c / total) * math.log2(c / total) for c in freq.values()) score = 0.0 flag = "" if entropy < self.cfg["entropy_min"]: score = 0.5 flag = "low_entropy_repetitive" elif entropy > self.cfg["entropy_max"]: score = 0.6 flag = "high_entropy_possibly_encoded" return score, flag, {"entropy": round(entropy, 3)} def _check_symbol_density(self, text: str) -> tuple[float, str, dict]: if not text: return 0.0, "", {} non_alnum = sum(1 for c in text if not c.isalnum() and not c.isspace()) density = non_alnum / len(text) score = 0.0 flag = "" if density > self.cfg["symbol_density_max"]: score = min(density, 1.0) flag = "high_symbol_density" return score, flag, {"symbol_density": round(density, 3)} def _check_encoding_obfuscation(self, text: str) -> tuple[float, str, dict]: score = 0.0 flags = [] details = {} # Unicode escape sequences esc_matches = _UNICODE_ESC_RE.findall(text) if len(esc_matches) >= self.cfg["unicode_escape_threshold"]: score += 0.5 flags.append("unicode_escape_sequences") details["unicode_escapes"] = len(esc_matches) # Base64-like blobs b64_matches = _BASE64_RE.findall(text) if b64_matches: score += 0.4 flags.append("base64_like_content") details["base64_blocks"] = len(b64_matches) # Long hex strings hex_matches = _HEX_RE.findall(text) if hex_matches: score += 0.3 flags.append("hex_encoded_content") details["hex_blocks"] = len(hex_matches) return min(score, 1.0), "|".join(flags), details def _check_homoglyphs(self, text: str) -> tuple[float, str, dict]: count = sum(1 for ch in text if ch in _HOMOGLYPH_MAP) score = 0.0 flag = "" if count >= self.cfg["homoglyph_threshold"]: score = min(count / 20, 1.0) flag = "homoglyph_substitution" return score, flag, {"homoglyph_count": count} def _check_unicode_normalization(self, text: str) -> tuple[float, str, dict]: """Detect invisible / zero-width / direction-override characters.""" bad_categories = {"Cf", "Cs", "Co"} # format, surrogate, private-use bad_chars = [c for c in text if unicodedata.category(c) in bad_categories] score = 0.0 flag = "" if len(bad_chars) > 2: score = min(len(bad_chars) / 10, 1.0) flag = "invisible_unicode_chars" return score, flag, {"invisible_char_count": len(bad_chars)} def _check_embedding_outlier(self, text: str) -> tuple[float, str, dict]: if not self.use_embeddings or self._embedder is None or self._normal_centroid is None: return 0.0, "", {} try: import numpy as np emb = self._embedder.encode(text, convert_to_numpy=True, normalize_embeddings=True) similarity = float(emb @ self._normal_centroid) distance = 1.0 - similarity # 0 = identical to normal, 1 = orthogonal score = max(0.0, (distance - 0.3) / 0.7) # linear rescale [0.3, 1.0] → [0, 1] flag = "embedding_outlier" if score > 0.3 else "" return score, flag, {"centroid_distance": round(distance, 4)} except Exception as exc: logger.debug("Embedding outlier check failed: %s", exc) return 0.0, "", {} # ------------------------------------------------------------------ # Aggregation # ------------------------------------------------------------------ def detect(self, text: str) -> AdversarialResult: """ Run all detection layers and return an AdversarialResult. Parameters ---------- text : str Raw user prompt. """ t0 = time.perf_counter() checks = [ self._check_length(text), self._check_repetition(text), self._check_entropy(text), self._check_symbol_density(text), self._check_encoding_obfuscation(text), self._check_homoglyphs(text), self._check_unicode_normalization(text), self._check_embedding_outlier(text), ] aggregate_score = 0.0 all_flags: List[str] = [] all_details: dict = {} weights = [0.15, 0.20, 0.15, 0.10, 0.20, 0.10, 0.10, 0.20] # sum > 1 ok; normalised below weight_sum = sum(weights) for (score, flag, details), weight in zip(checks, weights): aggregate_score += score * weight if flag: all_flags.extend(flag.split("|")) all_details.update(details) risk_score = min(aggregate_score / weight_sum, 1.0) is_adversarial = risk_score >= self.threshold latency = (time.perf_counter() - t0) * 1000 result = AdversarialResult( is_adversarial=is_adversarial, risk_score=risk_score, flags=list(filter(None, all_flags)), details=all_details, latency_ms=latency, ) if is_adversarial: logger.warning("Adversarial input detected | score=%.3f flags=%s", risk_score, all_flags) return result def is_safe(self, text: str) -> bool: return not self.detect(text).is_adversarial