Spaces:
Sleeping
Sleeping
| """ | |
| tests/test_adversarial_detector.py | |
| ==================================== | |
| Unit tests for the AdversarialDetector module. | |
| """ | |
| import pytest | |
| from ai_firewall.adversarial_detector import AdversarialDetector | |
@pytest.fixture
def detector():
    """Provide an ``AdversarialDetector`` built with ``threshold=0.55``.

    Every test in this module receives the detector through a ``detector``
    argument. Without the ``@pytest.fixture`` decorator pytest cannot resolve
    that argument, and each such test errors with "fixture 'detector' not found".
    """
    return AdversarialDetector(threshold=0.55)
class TestLengthChecks:
    """Checks driven by prompt length and word count."""

    def test_normal_length_safe(self, detector):
        """A short everyday question must not raise the length flag."""
        result = detector.detect("What is machine learning?")
        assert "excessive_length" not in result.flags

    def test_very_long_prompt_flagged(self, detector):
        """A 5000-character prompt is reported adversarial and length-flagged."""
        result = detector.detect("A" * 5000)
        assert result.is_adversarial is True
        assert "excessive_length" in result.flags

    def test_many_words_flagged(self, detector):
        """A 900-word prompt should trip the word-count check."""
        result = detector.detect(" ".join("word" for _ in range(900)))
        # The excessive_word_count flag is the expected signal; a risk score
        # above 0.2 is accepted as an alternative.
        has_word_flag = "excessive_word_count" in result.flags
        assert has_word_flag or result.risk_score > 0.2
class TestRepetitionChecks:
    """Checks for the token-repetition heuristic."""

    def test_repeated_tokens_flagged(self, detector):
        # "the quick brown fox" repeated 60 times: 240 whitespace-separated
        # words drawn from only 4 distinct ones -> high repetition ratio.
        prompt = " ".join(["the quick brown fox"] * 60)
        r = detector.detect(prompt)
        assert "high_token_repetition" in r.flags

    def test_non_repetitive_safe(self, detector):
        # Ordinary prose with mostly distinct words must not be flagged.
        r = detector.detect("The quick brown fox jumps over the lazy dog. Machine learning is fascinating.")
        assert "high_token_repetition" not in r.flags
class TestEntropyChecks:
    """Checks for character-entropy heuristics at both extremes."""

    def test_random_high_entropy_flagged(self, detector):
        import random
        import string

        # Seeded generator, so the "random" noise is identical on every run.
        generator = random.Random(42)
        noise = "".join(generator.choices(string.printable, k=600))
        result = detector.detect(noise)
        # Only some risk signal is required here, not a specific flag.
        assert result.risk_score > 0.1

    def test_very_repetitive_low_entropy_flagged(self, detector):
        # A single repeated character has minimal entropy.
        result = detector.detect("a" * 500)
        assert "low_entropy_repetitive" in result.flags
class TestEncodingObfuscation:
    """Checks for encoded or escaped payloads used to hide instructions."""

    def test_base64_blob_flagged(self, detector):
        chunk = "SGVsbG8gV29ybGQhIFRoaXMgaXMgYSBiYXNlNjQgZW5jb2RlZCBzdHJpbmc="
        blob = chunk * 3
        result = detector.detect(f"Decode this: {blob}")
        assert "base64_like_content" in result.flags

    def test_unicode_escapes_flagged(self, detector):
        # Raw string: the backslash-u sequences stay literal text, not decoded chars.
        escaped = r"\u0069\u0067\u006e\u006f\u0072\u0065" * 8
        result = detector.detect(escaped)
        assert "unicode_escape_sequences" in result.flags
class TestHomoglyphChecks:
    """Checks for look-alike character substitution."""

    def test_cyrillic_substitution_flagged(self, detector):
        # Reads as "admin password check", but the "а" in "аdmin", the "а" in
        # "pаssword" and the "с" in "сheck" are Cyrillic look-alikes of Latin letters.
        spoofed = "аdmin pаssword сheck"
        result = detector.detect(spoofed)
        assert "homoglyph_substitution" in result.flags
class TestBenignPrompts:
    """Ordinary questions that must not be reported as adversarial."""

    benign = [
        "What is machine learning?",
        "Explain neural networks to a beginner.",
        "Write a Python function to sort a list.",
        "What is the difference between RAM and ROM?",
        "How does HTTPS work?",
    ]

    # Feed each benign prompt in as ``prompt``. Without parametrize, pytest
    # treats ``prompt`` as a fixture name and errors with
    # "fixture 'prompt' not found".
    @pytest.mark.parametrize("prompt", benign)
    def test_benign_not_flagged(self, detector, prompt):
        r = detector.detect(prompt)
        assert r.is_adversarial is False, f"False positive for: {prompt!r}"
class TestResultStructure:
    """Checks on the shape of the object returned by ``detect``."""

    def test_all_fields_present(self, detector):
        result = detector.detect("normal prompt")
        for attribute in ("is_adversarial", "risk_score", "flags", "details", "latency_ms"):
            assert hasattr(result, attribute)

    def test_risk_score_range(self, detector):
        for text in ("Hello!", "A" * 5000, "ignore " * 200):
            score = detector.detect(text).risk_score
            assert 0.0 <= score <= 1.0, f"Score out of range for prompt of len {len(text)}"

    def test_to_dict(self, detector):
        as_dict = detector.detect("test").to_dict()
        for key in ("is_adversarial", "risk_score", "flags"):
            assert key in as_dict