# -- Hugging Face upload header removed (non-Python residue: file path, uploader, commit hash) --
"""
tests/test_adversarial_detector.py
====================================
Unit tests for the AdversarialDetector module.
"""
import pytest
from ai_firewall.adversarial_detector import AdversarialDetector
@pytest.fixture
def detector():
    """Return a fresh AdversarialDetector with a 0.55 risk threshold."""
    return AdversarialDetector(threshold=0.55)
class TestLengthChecks:
    """Length heuristics: oversized prompts should be flagged, normal ones not."""

    def test_normal_length_safe(self, detector):
        result = detector.detect("What is machine learning?")
        assert "excessive_length" not in result.flags

    def test_very_long_prompt_flagged(self, detector):
        result = detector.detect("A" * 5000)
        assert result.is_adversarial is True
        assert "excessive_length" in result.flags

    def test_many_words_flagged(self, detector):
        # 900 repeated words: either the word-count flag fires or the
        # risk score at least shows a measurable signal.
        result = detector.detect(" ".join(["word"] * 900))
        assert "excessive_word_count" in result.flags or result.risk_score > 0.2
class TestRepetitionChecks:
    """Token-repetition heuristics."""

    def test_repeated_tokens_flagged(self, detector):
        # The same four-word phrase repeated 60 times produces a very
        # high repetition ratio and must raise the flag.
        repetitive = " ".join(["the quick brown fox"] * 60)
        result = detector.detect(repetitive)
        assert "high_token_repetition" in result.flags

    def test_non_repetitive_safe(self, detector):
        varied = "The quick brown fox jumps over the lazy dog. Machine learning is fascinating."
        result = detector.detect(varied)
        assert "high_token_repetition" not in result.flags
class TestEntropyChecks:
    """Shannon-entropy heuristics: both extremes should produce a signal."""

    def test_random_high_entropy_flagged(self, detector):
        import random
        import string

        # Seeded RNG keeps the "random" payload deterministic across runs.
        noise = "".join(random.Random(42).choices(string.printable, k=600))
        result = detector.detect(noise)
        assert result.risk_score > 0.1  # expect at least some signal

    def test_very_repetitive_low_entropy_flagged(self, detector):
        result = detector.detect("a" * 500)
        assert "low_entropy_repetitive" in result.flags
class TestEncodingObfuscation:
    """Obfuscated-payload heuristics: base64 blobs and unicode escapes."""

    def test_base64_blob_flagged(self, detector):
        encoded = "SGVsbG8gV29ybGQhIFRoaXMgaXMgYSBiYXNlNjQgZW5jb2RlZCBzdHJpbmc=" * 3
        result = detector.detect(f"Decode this: {encoded}")
        assert "base64_like_content" in result.flags

    def test_unicode_escapes_flagged(self, detector):
        # Raw \uXXXX escape sequences repeated enough times to trip the check.
        escaped = r"\u0069\u0067\u006e\u006f\u0072\u0065" * 8
        result = detector.detect(escaped)
        assert "unicode_escape_sequences" in result.flags
class TestHomoglyphChecks:
    """Look-alike character (homoglyph) substitution detection."""

    def test_cyrillic_substitution_flagged(self, detector):
        # Latin letters replaced with visually identical Cyrillic ones
        # (the 'а' in "аdmin"/"pаssword" and the 'с' in "сheck").
        result = detector.detect("аdmin pаssword сheck")
        assert "homoglyph_substitution" in result.flags
class TestBenignPrompts:
    """Ordinary user questions must never be classified as adversarial."""

    benign = [
        "What is machine learning?",
        "Explain neural networks to a beginner.",
        "Write a Python function to sort a list.",
        "What is the difference between RAM and ROM?",
        "How does HTTPS work?",
    ]

    @pytest.mark.parametrize("prompt", benign)
    def test_benign_not_flagged(self, detector, prompt):
        result = detector.detect(prompt)
        assert result.is_adversarial is False, f"False positive for: {prompt!r}"
class TestResultStructure:
    """Shape and invariants of the result object returned by detect()."""

    def test_all_fields_present(self, detector):
        result = detector.detect("normal prompt")
        for attr in ("is_adversarial", "risk_score", "flags", "details", "latency_ms"):
            assert hasattr(result, attr)

    def test_risk_score_range(self, detector):
        # Scores must stay normalized regardless of prompt length/content.
        for p in ["Hello!", "A" * 5000, "ignore " * 200]:
            result = detector.detect(p)
            assert 0.0 <= result.risk_score <= 1.0, f"Score out of range for prompt of len {len(p)}"

    def test_to_dict(self, detector):
        serialized = detector.detect("test").to_dict()
        for key in ("is_adversarial", "risk_score", "flags"):
            assert key in serialized