# SheildSense_API_SDK / ai_firewall / adversarial_detector.py
"""
adversarial_detector.py
========================
Detects adversarial / anomalous inputs that may be crafted to manipulate
AI models or evade safety filters.
Detection layers (all zero-dependency except the optional embedding layer):
1. Token-length analysis — unusually long or repetitive prompts
2. Character distribution — abnormal char class ratios (unicode tricks, homoglyphs)
3. Repetition detection — token/n-gram flooding
4. Encoding obfuscation — base64 blobs, hex strings, ROT-13 traces
5. Statistical anomaly — entropy, symbol density, whitespace abuse
6. Embedding outlier — cosine distance from "normal" centroid (optional)
"""
from __future__ import annotations
import re
import math
import time
import unicodedata
import logging
from collections import Counter
from dataclasses import dataclass, field
from typing import List, Optional
logger = logging.getLogger("ai_firewall.adversarial_detector")
# ---------------------------------------------------------------------------
# Config defaults (tunable without subclassing)
# ---------------------------------------------------------------------------
# Every key can be overridden per-instance via AdversarialDetector(config={...}).
DEFAULT_CONFIG = {
    "max_token_length": 4096,       # chars (rough token proxy)
    "max_word_count": 800,          # whitespace-separated words
    "max_line_count": 200,          # newline count, compared with ">" in _check_length
    "repetition_threshold": 0.45,   # fraction of repeated trigrams → adversarial
    "entropy_min": 2.5,             # too-low entropy = repetitive junk
    "entropy_max": 5.8,             # too-high entropy = encoded/random content
    "symbol_density_max": 0.35,     # fraction of non-alphanumeric chars
    "unicode_escape_threshold": 5,  # count of \uXXXX / \xXX sequences
    "base64_min_length": 40,        # minimum length of candidate b64 blocks
                                    # NOTE(review): appears unused — _BASE64_RE below
                                    # hardcodes the 40-char minimum ({4} x {10,}); confirm.
    "homoglyph_threshold": 3,       # count of confusable lookalike chars
}
# Homoglyph mapping (Cyrillic / Greek / other confusable lookalikes for latin)
_HOMOGLYPH_MAP = {
    "а": "a", "е": "e", "і": "i", "о": "o", "р": "p", "с": "c",
    "х": "x", "у": "y", "ѕ": "s", "ј": "j", "ԁ": "d", "ɡ": "g",
    "ʜ": "h", "ᴛ": "t", "ᴡ": "w", "ᴍ": "m", "ᴋ": "k",
    "α": "a", "ε": "e", "ο": "o", "ρ": "p", "ν": "v", "κ": "k",
}
# Base64-like blob: ten or more 4-char base64 groups (>= 40 chars), with
# optional standard '='/'==' padding at the end.
_BASE64_RE = re.compile(r"(?:[A-Za-z0-9+/]{4}){10,}(?:[A-Za-z0-9+/]{2}==|[A-Za-z0-9+/]{3}=)?")
# Long hex runs (>= 16 hex digits), with optional 0x prefix.
_HEX_RE = re.compile(r"(?:0x)?[0-9a-fA-F]{16,}")
# \uXXXX, \xXX and %XX (URL-style) escape sequences.
_UNICODE_ESC_RE = re.compile(r"(\\u[0-9a-fA-F]{4}|\\x[0-9a-fA-F]{2}|%[0-9a-fA-F]{2})")
@dataclass
class AdversarialResult:
    """Outcome of a single adversarial-input scan."""
    is_adversarial: bool          # True when risk_score >= detector threshold
    risk_score: float             # aggregate risk in [0.0, 1.0]
    flags: List[str] = field(default_factory=list)   # names of triggered checks
    details: dict = field(default_factory=dict)      # per-check measurements
    latency_ms: float = 0.0       # wall-clock time spent in detect()

    def to_dict(self) -> dict:
        """Serialise to a plain JSON-friendly dict, rounding the numeric fields."""
        payload = dict(
            is_adversarial=self.is_adversarial,
            risk_score=round(self.risk_score, 4),
            flags=self.flags,
            details=self.details,
            latency_ms=round(self.latency_ms, 2),
        )
        return payload
class AdversarialDetector:
    """
    Stateless adversarial input detector.

    Runs a battery of cheap lexical/statistical checks, plus an optional
    embedding-outlier check, and aggregates them into one risk score.
    A prompt is considered adversarial if its aggregate risk score
    exceeds `threshold` (default 0.55).

    Parameters
    ----------
    threshold : float
        Risk score above which input is flagged.
    config : dict, optional
        Override any key from DEFAULT_CONFIG.
    use_embeddings : bool
        Enable embedding-outlier detection (requires sentence-transformers).
    embedding_model : str
        Model name for the embedding layer.
    """

    def __init__(
        self,
        threshold: float = 0.55,
        config: Optional[dict] = None,
        use_embeddings: bool = False,
        embedding_model: str = "all-MiniLM-L6-v2",
    ) -> None:
        self.threshold = threshold
        # Shallow merge: caller-supplied keys win over DEFAULT_CONFIG.
        self.cfg = {**DEFAULT_CONFIG, **(config or {})}
        self.use_embeddings = use_embeddings
        self._embedder = None
        self._normal_centroid = None  # set via `fit_normal_distribution`
        if use_embeddings:
            self._load_embedder(embedding_model)

    # ------------------------------------------------------------------
    # Embedding layer
    # ------------------------------------------------------------------
    def _load_embedder(self, model_name: str) -> None:
        """Load sentence-transformers; silently disable the layer if absent."""
        try:
            from sentence_transformers import SentenceTransformer
            import numpy as np  # noqa: F401 — fail early here if numpy is missing
            self._embedder = SentenceTransformer(model_name)
            logger.info("Adversarial embedding layer loaded: %s", model_name)
        except ImportError:
            logger.warning("sentence-transformers not installed — embedding outlier layer disabled.")
            self.use_embeddings = False

    def fit_normal_distribution(self, normal_prompts: List[str]) -> None:
        """
        Compute the centroid of embedding vectors for a set of known-good
        prompts. Call this once at startup with representative benign prompts.

        No-op when the embedding layer is disabled, `normal_prompts` is empty
        (previously produced a NaN centroid), or the centroid is degenerate.
        """
        if not self.use_embeddings or self._embedder is None:
            return
        if not normal_prompts:
            # Fix: mean over an empty array yields NaN and norm division fails.
            logger.warning("fit_normal_distribution called with no prompts — centroid not updated.")
            return
        import numpy as np
        embeddings = self._embedder.encode(
            normal_prompts, convert_to_numpy=True, normalize_embeddings=True
        )
        centroid = embeddings.mean(axis=0)
        norm = np.linalg.norm(centroid)
        if norm == 0.0:
            # Opposing unit vectors can cancel out; a zero centroid is unusable.
            logger.warning("Degenerate zero-norm centroid — centroid not updated.")
            return
        self._normal_centroid = centroid / norm
        logger.info("Normal centroid computed from %d prompts.", len(normal_prompts))

    # ------------------------------------------------------------------
    # Individual checks — each returns (score, "flag1|flag2", details)
    # ------------------------------------------------------------------
    def _check_length(self, text: str) -> tuple[float, str, dict]:
        """Flag unusually long inputs by characters, words, and lines."""
        char_len = len(text)
        word_count = len(text.split())
        line_count = text.count("\n")  # newline count used as a line-count proxy
        score = 0.0
        flags: List[str] = []
        if char_len > self.cfg["max_token_length"]:
            score += 0.4
            flags.append("excessive_length")
        if word_count > self.cfg["max_word_count"]:
            score += 0.25
            flags.append("excessive_word_count")
        if line_count > self.cfg["max_line_count"]:
            score += 0.2
            flags.append("excessive_line_count")
        details = {"char_len": char_len, "word_count": word_count, "line_count": line_count}
        return min(score, 1.0), "|".join(flags), details

    def _check_repetition(self, text: str) -> tuple[float, str, dict]:
        """Flag n-gram flooding via the fraction of duplicate word trigrams."""
        words = text.lower().split()
        if len(words) < 6:
            # Too short for a meaningful trigram statistic.
            return 0.0, "", {}
        trigrams = [tuple(words[i:i + 3]) for i in range(len(words) - 2)]
        # `trigrams` is guaranteed non-empty here (>= 4 items for 6+ words),
        # so the old `if not trigrams` guard was dead code and is removed.
        total = len(trigrams)
        repetition_ratio = 1.0 - (len(set(trigrams)) / total)
        score = 0.0
        flag = ""
        if repetition_ratio >= self.cfg["repetition_threshold"]:
            score = min(repetition_ratio, 1.0)
            flag = "high_token_repetition"
        return score, flag, {"repetition_ratio": round(repetition_ratio, 3)}

    def _check_entropy(self, text: str) -> tuple[float, str, dict]:
        """Shannon entropy of the character distribution (bits per char)."""
        if not text:
            return 0.0, "", {}
        freq = Counter(text)
        total = len(text)
        entropy = -sum((c / total) * math.log2(c / total) for c in freq.values())
        score = 0.0
        flag = ""
        if entropy < self.cfg["entropy_min"]:
            score = 0.5
            flag = "low_entropy_repetitive"
        elif entropy > self.cfg["entropy_max"]:
            score = 0.6
            flag = "high_entropy_possibly_encoded"
        return score, flag, {"entropy": round(entropy, 3)}

    def _check_symbol_density(self, text: str) -> tuple[float, str, dict]:
        """Fraction of characters that are neither alphanumeric nor whitespace."""
        if not text:
            return 0.0, "", {}
        non_alnum = sum(1 for c in text if not c.isalnum() and not c.isspace())
        density = non_alnum / len(text)
        score = 0.0
        flag = ""
        if density > self.cfg["symbol_density_max"]:
            score = min(density, 1.0)
            flag = "high_symbol_density"
        return score, flag, {"symbol_density": round(density, 3)}

    def _check_encoding_obfuscation(self, text: str) -> tuple[float, str, dict]:
        """Detect escape-sequence floods, base64-like blobs, and long hex runs."""
        score = 0.0
        flags = []
        details = {}
        # Unicode / URL escape sequences (\uXXXX, \xXX, %XX)
        esc_matches = _UNICODE_ESC_RE.findall(text)
        if len(esc_matches) >= self.cfg["unicode_escape_threshold"]:
            score += 0.5
            flags.append("unicode_escape_sequences")
            details["unicode_escapes"] = len(esc_matches)
        # Base64-like blobs
        b64_matches = _BASE64_RE.findall(text)
        if b64_matches:
            score += 0.4
            flags.append("base64_like_content")
            details["base64_blocks"] = len(b64_matches)
        # Long hex strings
        hex_matches = _HEX_RE.findall(text)
        if hex_matches:
            score += 0.3
            flags.append("hex_encoded_content")
            details["hex_blocks"] = len(hex_matches)
        return min(score, 1.0), "|".join(flags), details

    def _check_homoglyphs(self, text: str) -> tuple[float, str, dict]:
        """Count confusable lookalike characters (Cyrillic/Greek stand-ins)."""
        count = sum(1 for ch in text if ch in _HOMOGLYPH_MAP)
        score = 0.0
        flag = ""
        if count >= self.cfg["homoglyph_threshold"]:
            score = min(count / 20, 1.0)  # saturates at 20 homoglyphs
            flag = "homoglyph_substitution"
        return score, flag, {"homoglyph_count": count}

    def _check_unicode_normalization(self, text: str) -> tuple[float, str, dict]:
        """Detect invisible / zero-width / direction-override characters."""
        bad_categories = {"Cf", "Cs", "Co"}  # format, surrogate, private-use
        bad_chars = [c for c in text if unicodedata.category(c) in bad_categories]
        score = 0.0
        flag = ""
        if len(bad_chars) > 2:  # a couple can be legitimate (e.g. ZWJ in emoji)
            score = min(len(bad_chars) / 10, 1.0)
            flag = "invisible_unicode_chars"
        return score, flag, {"invisible_char_count": len(bad_chars)}

    def _check_embedding_outlier(self, text: str) -> tuple[float, str, dict]:
        """Cosine distance from the fitted 'normal' centroid (optional layer)."""
        if not self.use_embeddings or self._embedder is None or self._normal_centroid is None:
            return 0.0, "", {}
        try:
            # Fix: removed an unused `import numpy as np` — the `@` operator
            # works directly on the arrays returned by encode().
            emb = self._embedder.encode(text, convert_to_numpy=True, normalize_embeddings=True)
            similarity = float(emb @ self._normal_centroid)
            distance = 1.0 - similarity  # 0 = identical to normal, 1 = orthogonal
            score = max(0.0, (distance - 0.3) / 0.7)  # linear rescale [0.3, 1.0] → [0, 1]
            flag = "embedding_outlier" if score > 0.3 else ""
            return score, flag, {"centroid_distance": round(distance, 4)}
        except Exception as exc:  # best-effort layer: never break the pipeline
            logger.debug("Embedding outlier check failed: %s", exc)
            return 0.0, "", {}

    # ------------------------------------------------------------------
    # Aggregation
    # ------------------------------------------------------------------
    def detect(self, text: str) -> AdversarialResult:
        """
        Run all detection layers and return an AdversarialResult.

        Parameters
        ----------
        text : str
            Raw user prompt.

        Returns
        -------
        AdversarialResult
            Weighted-mean risk score, triggered flags, per-layer details,
            and wall-clock latency in milliseconds.
        """
        t0 = time.perf_counter()
        checks = [
            self._check_length(text),
            self._check_repetition(text),
            self._check_entropy(text),
            self._check_symbol_density(text),
            self._check_encoding_obfuscation(text),
            self._check_homoglyphs(text),
            self._check_unicode_normalization(text),
            self._check_embedding_outlier(text),
        ]
        # Per-layer weights, same order as `checks`; they need not sum to 1
        # because the aggregate is normalised by `weight_sum` below.
        weights = [0.15, 0.20, 0.15, 0.10, 0.20, 0.10, 0.10, 0.20]
        weight_sum = sum(weights)
        aggregate_score = 0.0
        all_flags: List[str] = []
        all_details: dict = {}
        for (score, flag, details), weight in zip(checks, weights):
            aggregate_score += score * weight
            if flag:
                # Multi-flag checks join their flag names with "|".
                all_flags.extend(flag.split("|"))
            all_details.update(details)
        risk_score = min(aggregate_score / weight_sum, 1.0)
        is_adversarial = risk_score >= self.threshold
        latency = (time.perf_counter() - t0) * 1000
        result = AdversarialResult(
            is_adversarial=is_adversarial,
            risk_score=risk_score,
            flags=list(filter(None, all_flags)),
            details=all_details,
            latency_ms=latency,
        )
        if is_adversarial:
            logger.warning("Adversarial input detected | score=%.3f flags=%s", risk_score, all_flags)
        return result

    def is_safe(self, text: str) -> bool:
        """Convenience wrapper: True when `text` is NOT flagged as adversarial."""
        return not self.detect(text).is_adversarial