Spaces:
Sleeping
Sleeping
| """ | |
| adversarial_detector.py | |
| ======================== | |
| Detects adversarial / anomalous inputs that may be crafted to manipulate | |
| AI models or evade safety filters. | |
| Detection layers (all zero-dependency except the optional embedding layer): | |
| 1. Token-length analysis — unusually long or repetitive prompts | |
| 2. Character distribution — abnormal char class ratios (unicode tricks, homoglyphs) | |
| 3. Repetition detection — token/n-gram flooding | |
| 4. Encoding obfuscation — base64 blobs, hex strings, ROT-13 traces | |
| 5. Statistical anomaly — entropy, symbol density, whitespace abuse | |
| 6. Embedding outlier — cosine distance from "normal" centroid (optional) | |
| """ | |
| from __future__ import annotations | |
| import re | |
| import math | |
| import time | |
| import unicodedata | |
| import logging | |
| from collections import Counter | |
| from dataclasses import dataclass, field | |
| from typing import List, Optional | |
| logger = logging.getLogger("ai_firewall.adversarial_detector") | |
# ---------------------------------------------------------------------------
# Config defaults (tunable without subclassing)
# ---------------------------------------------------------------------------
# Every key can be overridden per-instance via the `config` argument of
# AdversarialDetector; values are heuristics, not hard guarantees.
DEFAULT_CONFIG = {
    "max_token_length": 4096,        # chars (rough token proxy)
    "max_word_count": 800,
    "max_line_count": 200,
    "repetition_threshold": 0.45,    # fraction of repeated trigrams → adversarial
    "entropy_min": 2.5,              # too-low entropy = repetitive junk
    "entropy_max": 5.8,              # too-high entropy = encoded/random content
    "symbol_density_max": 0.35,      # fraction of non-alphanumeric chars
    "unicode_escape_threshold": 5,   # count of \uXXXX / \xXX sequences
    "base64_min_length": 40,         # minimum length of candidate b64 blocks
    "homoglyph_threshold": 3,        # count of confusable lookalike chars
}

# Homoglyph mapping (Cyrillic / Greek / other confusable lookalikes for latin).
# Only the keys are consulted (membership test in _check_homoglyphs); the latin
# values document which ASCII letter each character imitates.
_HOMOGLYPH_MAP = {
    "а": "a", "е": "e", "і": "i", "о": "o", "р": "p", "с": "c",
    "х": "x", "у": "y", "ѕ": "s", "ј": "j", "ԁ": "d", "ɡ": "g",
    "ʜ": "h", "ᴛ": "t", "ᴡ": "w", "ᴍ": "m", "ᴋ": "k",
    "α": "a", "ε": "e", "ο": "o", "ρ": "p", "ν": "v", "κ": "k",
}

# Runs of >= 40 base64-alphabet chars (10 groups of 4), optional "=" padding.
_BASE64_RE = re.compile(r"(?:[A-Za-z0-9+/]{4}){10,}(?:[A-Za-z0-9+/]{2}==|[A-Za-z0-9+/]{3}=)?")
# 16+ hex digits, with or without a "0x" prefix.
_HEX_RE = re.compile(r"(?:0x)?[0-9a-fA-F]{16,}")
# Literal \uXXXX / \xXX escape sequences and %XX URL-encoding.
_UNICODE_ESC_RE = re.compile(r"(\\u[0-9a-fA-F]{4}|\\x[0-9a-fA-F]{2}|%[0-9a-fA-F]{2})")
@dataclass
class AdversarialResult:
    """Outcome of a single `AdversarialDetector.detect` call.

    BUG FIX: the original class declared dataclass-style fields (including
    ``field(default_factory=list)``) but was missing the ``@dataclass``
    decorator, so keyword construction ``AdversarialResult(is_adversarial=...)``
    raised ``TypeError`` and the mutable defaults were shared class attributes.
    """

    is_adversarial: bool                         # final verdict vs. threshold
    risk_score: float                            # aggregate score, 0.0 – 1.0
    flags: List[str] = field(default_factory=list)   # triggered check names
    details: dict = field(default_factory=dict)      # per-check measurements
    latency_ms: float = 0.0                      # wall-clock detection time

    def to_dict(self) -> dict:
        """Return a JSON-serializable view with rounded numeric fields."""
        return {
            "is_adversarial": self.is_adversarial,
            "risk_score": round(self.risk_score, 4),
            "flags": self.flags,
            "details": self.details,
            "latency_ms": round(self.latency_ms, 2),
        }
class AdversarialDetector:
    """
    Stateless adversarial input detector.

    A prompt is considered adversarial if its aggregate risk score
    exceeds `threshold` (default 0.55).

    Each `_check_*` method returns a `(score, flag, details)` triple:
    score in [0, 1], a pipe-joined flag string ("" when clean), and a
    dict of raw measurements. `detect` combines them with fixed weights.

    Parameters
    ----------
    threshold : float
        Risk score above which input is flagged.
    config : dict, optional
        Override any key from DEFAULT_CONFIG.
    use_embeddings : bool
        Enable embedding-outlier detection (requires sentence-transformers).
    embedding_model : str
        Model name for the embedding layer.
    """

    def __init__(
        self,
        threshold: float = 0.55,
        config: Optional[dict] = None,
        use_embeddings: bool = False,
        embedding_model: str = "all-MiniLM-L6-v2",
    ) -> None:
        self.threshold = threshold
        # Shallow merge: user-supplied keys override module defaults.
        self.cfg = {**DEFAULT_CONFIG, **(config or {})}
        self.use_embeddings = use_embeddings
        self._embedder = None  # SentenceTransformer instance once loaded
        self._normal_centroid = None  # set via `fit_normal_distribution`
        if use_embeddings:
            self._load_embedder(embedding_model)

    # ------------------------------------------------------------------
    # Embedding layer
    # ------------------------------------------------------------------
    def _load_embedder(self, model_name: str) -> None:
        """Load the optional sentence-transformers model; disable the
        embedding layer gracefully if the dependency is missing."""
        try:
            from sentence_transformers import SentenceTransformer
            # NOTE(review): numpy is imported here but not used in this
            # method; it is re-imported where actually needed.
            import numpy as np
            self._embedder = SentenceTransformer(model_name)
            logger.info("Adversarial embedding layer loaded: %s", model_name)
        except ImportError:
            logger.warning("sentence-transformers not installed — embedding outlier layer disabled.")
            self.use_embeddings = False

    def fit_normal_distribution(self, normal_prompts: List[str]) -> None:
        """
        Compute the centroid of embedding vectors for a set of known-good
        prompts. Call this once at startup with representative benign prompts.
        """
        if not self.use_embeddings or self._embedder is None:
            return
        import numpy as np
        embeddings = self._embedder.encode(normal_prompts, convert_to_numpy=True, normalize_embeddings=True)
        # Mean of unit vectors, re-normalized so dot product == cosine similarity.
        self._normal_centroid = embeddings.mean(axis=0)
        self._normal_centroid /= np.linalg.norm(self._normal_centroid)
        logger.info("Normal centroid computed from %d prompts.", len(normal_prompts))

    # ------------------------------------------------------------------
    # Individual checks
    # ------------------------------------------------------------------
    def _check_length(self, text: str) -> tuple[float, str, dict]:
        """Flag inputs that exceed configured size limits (chars/words/lines)."""
        char_len = len(text)
        word_count = len(text.split())
        # NOTE(review): this counts newline characters, i.e. lines - 1 for
        # non-empty text — confirm the max_line_count threshold intends that.
        line_count = text.count("\n")
        score = 0.0
        details, flags = {}, []
        if char_len > self.cfg["max_token_length"]:
            score += 0.4
            flags.append("excessive_length")
        if word_count > self.cfg["max_word_count"]:
            score += 0.25
            flags.append("excessive_word_count")
        if line_count > self.cfg["max_line_count"]:
            score += 0.2
            flags.append("excessive_line_count")
        details = {"char_len": char_len, "word_count": word_count, "line_count": line_count}
        # Sub-scores are additive (max 0.85 here), clamped to 1.0 for safety.
        return min(score, 1.0), "|".join(flags), details

    def _check_repetition(self, text: str) -> tuple[float, str, dict]:
        """Detect token flooding via the fraction of duplicated word trigrams."""
        words = text.lower().split()
        if len(words) < 6:
            # Too short for trigram statistics to be meaningful.
            return 0.0, "", {}
        trigrams = [tuple(words[i:i+3]) for i in range(len(words) - 2)]
        if not trigrams:
            return 0.0, "", {}
        total = len(trigrams)
        unique = len(set(trigrams))
        # 0.0 = all trigrams distinct, → 1.0 as the text repeats itself.
        repetition_ratio = 1.0 - (unique / total)
        score = 0.0
        flag = ""
        if repetition_ratio >= self.cfg["repetition_threshold"]:
            score = min(repetition_ratio, 1.0)
            flag = "high_token_repetition"
        return score, flag, {"repetition_ratio": round(repetition_ratio, 3)}

    def _check_entropy(self, text: str) -> tuple[float, str, dict]:
        """Shannon entropy (bits/char): low = repetitive junk, high = likely
        encoded or random content."""
        if not text:
            return 0.0, "", {}
        freq = Counter(text)
        total = len(text)
        entropy = -sum((c / total) * math.log2(c / total) for c in freq.values())
        score = 0.0
        flag = ""
        if entropy < self.cfg["entropy_min"]:
            score = 0.5
            flag = "low_entropy_repetitive"
        elif entropy > self.cfg["entropy_max"]:
            score = 0.6
            flag = "high_entropy_possibly_encoded"
        return score, flag, {"entropy": round(entropy, 3)}

    def _check_symbol_density(self, text: str) -> tuple[float, str, dict]:
        """Fraction of characters that are neither alphanumeric nor whitespace."""
        if not text:
            return 0.0, "", {}
        non_alnum = sum(1 for c in text if not c.isalnum() and not c.isspace())
        density = non_alnum / len(text)
        score = 0.0
        flag = ""
        if density > self.cfg["symbol_density_max"]:
            score = min(density, 1.0)
            flag = "high_symbol_density"
        return score, flag, {"symbol_density": round(density, 3)}

    def _check_encoding_obfuscation(self, text: str) -> tuple[float, str, dict]:
        """Look for payloads hidden behind escape sequences, base64, or hex."""
        score = 0.0
        flags = []
        details = {}
        # Unicode escape sequences (\uXXXX, \xXX, %XX)
        esc_matches = _UNICODE_ESC_RE.findall(text)
        if len(esc_matches) >= self.cfg["unicode_escape_threshold"]:
            score += 0.5
            flags.append("unicode_escape_sequences")
            details["unicode_escapes"] = len(esc_matches)
        # Base64-like blobs (40+ chars of base64 alphabet)
        b64_matches = _BASE64_RE.findall(text)
        if b64_matches:
            score += 0.4
            flags.append("base64_like_content")
            details["base64_blocks"] = len(b64_matches)
        # Long hex strings (16+ hex digits)
        hex_matches = _HEX_RE.findall(text)
        if hex_matches:
            score += 0.3
            flags.append("hex_encoded_content")
            details["hex_blocks"] = len(hex_matches)
        return min(score, 1.0), "|".join(flags), details

    def _check_homoglyphs(self, text: str) -> tuple[float, str, dict]:
        """Count confusable lookalike characters (see _HOMOGLYPH_MAP)."""
        count = sum(1 for ch in text if ch in _HOMOGLYPH_MAP)
        score = 0.0
        flag = ""
        if count >= self.cfg["homoglyph_threshold"]:
            # Linear ramp: 20+ homoglyphs saturates the score at 1.0.
            score = min(count / 20, 1.0)
            flag = "homoglyph_substitution"
        return score, flag, {"homoglyph_count": count}

    def _check_unicode_normalization(self, text: str) -> tuple[float, str, dict]:
        """Detect invisible / zero-width / direction-override characters."""
        bad_categories = {"Cf", "Cs", "Co"}  # format, surrogate, private-use
        bad_chars = [c for c in text if unicodedata.category(c) in bad_categories]
        score = 0.0
        flag = ""
        if len(bad_chars) > 2:
            # 10+ invisible characters saturates the score at 1.0.
            score = min(len(bad_chars) / 10, 1.0)
            flag = "invisible_unicode_chars"
        return score, flag, {"invisible_char_count": len(bad_chars)}

    def _check_embedding_outlier(self, text: str) -> tuple[float, str, dict]:
        """Cosine-distance of the prompt from the fitted benign centroid.
        Returns a zero score when the embedding layer is disabled/unfitted."""
        if not self.use_embeddings or self._embedder is None or self._normal_centroid is None:
            return 0.0, "", {}
        try:
            import numpy as np
            emb = self._embedder.encode(text, convert_to_numpy=True, normalize_embeddings=True)
            # Both vectors are unit-length, so dot product == cosine similarity.
            similarity = float(emb @ self._normal_centroid)
            distance = 1.0 - similarity  # 0 = identical to normal, 1 = orthogonal
            score = max(0.0, (distance - 0.3) / 0.7)  # linear rescale [0.3, 1.0] → [0, 1]
            flag = "embedding_outlier" if score > 0.3 else ""
            return score, flag, {"centroid_distance": round(distance, 4)}
        except Exception as exc:
            # Best-effort layer: never let an embedding failure break detection.
            logger.debug("Embedding outlier check failed: %s", exc)
            return 0.0, "", {}

    # ------------------------------------------------------------------
    # Aggregation
    # ------------------------------------------------------------------
    def detect(self, text: str) -> AdversarialResult:
        """
        Run all detection layers and return an AdversarialResult.

        Parameters
        ----------
        text : str
            Raw user prompt.
        """
        t0 = time.perf_counter()
        # Order matters: each entry must line up with `weights` below.
        checks = [
            self._check_length(text),
            self._check_repetition(text),
            self._check_entropy(text),
            self._check_symbol_density(text),
            self._check_encoding_obfuscation(text),
            self._check_homoglyphs(text),
            self._check_unicode_normalization(text),
            self._check_embedding_outlier(text),
        ]
        aggregate_score = 0.0
        all_flags: List[str] = []
        all_details: dict = {}
        weights = [0.15, 0.20, 0.15, 0.10, 0.20, 0.10, 0.10, 0.20]  # sum > 1 ok; normalised below
        weight_sum = sum(weights)
        for (score, flag, details), weight in zip(checks, weights):
            aggregate_score += score * weight
            if flag:
                # Multi-flag checks join with "|"; split back to individual flags.
                all_flags.extend(flag.split("|"))
            all_details.update(details)
        # Weighted mean of check scores, clamped to [0, 1].
        risk_score = min(aggregate_score / weight_sum, 1.0)
        is_adversarial = risk_score >= self.threshold
        latency = (time.perf_counter() - t0) * 1000
        result = AdversarialResult(
            is_adversarial=is_adversarial,
            risk_score=risk_score,
            flags=list(filter(None, all_flags)),
            details=all_details,
            latency_ms=latency,
        )
        if is_adversarial:
            logger.warning("Adversarial input detected | score=%.3f flags=%s", risk_score, all_flags)
        return result

    def is_safe(self, text: str) -> bool:
        """Convenience inverse of `detect`: True when the input is NOT flagged."""
        return not self.detect(text).is_adversarial