File size: 5,450 Bytes
e363dc8
 
 
 
 
07d8601
e363dc8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3688dd6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
07d8601
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
"""Tests for structured audit logging."""

from __future__ import annotations

import json
import os

from agent_bench.security.audit_logger import AuditLogger


class TestAuditLogger:
    def test_log_creates_file(self, tmp_path):
        log_path = tmp_path / "audit.jsonl"
        logger = AuditLogger(path=str(log_path))
        logger.log({"request_id": "test-1", "endpoint": "/ask"})
        assert log_path.exists()

    def test_log_appends_jsonl(self, tmp_path):
        log_path = tmp_path / "audit.jsonl"
        logger = AuditLogger(path=str(log_path))
        logger.log({"request_id": "r1"})
        logger.log({"request_id": "r2"})
        lines = log_path.read_text().strip().split("\n")
        assert len(lines) == 2
        assert json.loads(lines[0])["request_id"] == "r1"
        assert json.loads(lines[1])["request_id"] == "r2"

    def test_log_adds_timestamp(self, tmp_path):
        log_path = tmp_path / "audit.jsonl"
        logger = AuditLogger(path=str(log_path))
        logger.log({"request_id": "r1"})
        record = json.loads(log_path.read_text().strip())
        assert "timestamp" in record

    def test_hash_ip(self):
        logger = AuditLogger(path="/dev/null")
        hashed = logger.hash_ip("192.168.1.1")
        # Deterministic
        assert hashed == logger.hash_ip("192.168.1.1")
        # Not the raw IP
        assert "192.168.1.1" not in hashed
        # SHA-256 hex = 64 chars
        assert len(hashed) == 64

    def test_hash_ip_different_inputs(self):
        logger = AuditLogger(path="/dev/null")
        assert logger.hash_ip("10.0.0.1") != logger.hash_ip("10.0.0.2")

    def test_log_rotation(self, tmp_path):
        log_path = tmp_path / "audit.jsonl"
        # 1 byte max size to force rotation on second write
        logger = AuditLogger(path=str(log_path), max_size_bytes=1, rotate=True)
        logger.log({"request_id": "r1"})
        logger.log({"request_id": "r2"})
        # Original file should still exist with latest record
        assert log_path.exists()
        # Rotated file should exist
        rotated = list(tmp_path.glob("audit.jsonl.*"))
        assert len(rotated) >= 1

    def test_no_rotation_when_disabled(self, tmp_path):
        log_path = tmp_path / "audit.jsonl"
        logger = AuditLogger(path=str(log_path), max_size_bytes=1, rotate=False)
        logger.log({"request_id": "r1"})
        logger.log({"request_id": "r2"})
        rotated = list(tmp_path.glob("audit.jsonl.*"))
        assert len(rotated) == 0

    def test_creates_parent_directories(self, tmp_path):
        log_path = tmp_path / "nested" / "dir" / "audit.jsonl"
        logger = AuditLogger(path=str(log_path))
        logger.log({"request_id": "r1"})
        assert log_path.exists()

    def test_multiple_rotations_no_data_loss(self, tmp_path):
        """Multiple rotations in the same second must not overwrite each other."""
        log_path = tmp_path / "audit.jsonl"
        logger = AuditLogger(path=str(log_path), max_size_bytes=1, rotate=True)
        logger.log({"request_id": "r1"})
        logger.log({"request_id": "r2"})
        logger.log({"request_id": "r3"})
        # All 3 records must survive: 2 in rotated files, 1 in active log
        rotated = list(tmp_path.glob("audit.jsonl.*"))
        assert len(rotated) == 2
        all_records = []
        for f in [log_path, *rotated]:
            for line in f.read_text().strip().split("\n"):
                all_records.append(json.loads(line)["request_id"])
        assert sorted(all_records) == ["r1", "r2", "r3"]

    def test_hash_ip_different_keys_produce_different_hashes(self):
        """Different HMAC keys produce different hashes for the same IP."""
        logger_a = AuditLogger(path="/dev/null", hmac_key="key-a")
        logger_b = AuditLogger(path="/dev/null", hmac_key="key-b")
        assert logger_a.hash_ip("192.168.1.1") != logger_b.hash_ip("192.168.1.1")

    def test_hash_ip_stable_with_same_key(self):
        """Same HMAC key produces consistent hashes across instances."""
        logger_a = AuditLogger(path="/dev/null", hmac_key="stable-key")
        logger_b = AuditLogger(path="/dev/null", hmac_key="stable-key")
        assert logger_a.hash_ip("10.0.0.1") == logger_b.hash_ip("10.0.0.1")

    def test_multi_instance_rotation_no_data_loss(self, tmp_path):
        """Two logger instances rotating the same file must not overwrite each other."""
        log_path = tmp_path / "audit.jsonl"
        logger_a = AuditLogger(path=str(log_path), max_size_bytes=1, rotate=True)
        logger_b = AuditLogger(path=str(log_path), max_size_bytes=1, rotate=True)
        logger_a.log({"request_id": "r1"})
        logger_b.log({"request_id": "r2"})
        logger_a.log({"request_id": "r3"})
        # All 3 records must survive across rotated files + active log
        all_records = []
        for f in tmp_path.glob("audit.jsonl*"):
            for line in f.read_text().strip().split("\n"):
                if line:
                    all_records.append(json.loads(line)["request_id"])
        assert sorted(all_records) == ["r1", "r2", "r3"]

    def test_no_hmac_key_logs_warning(self, tmp_path, capsys):
        """Default-constructed logger warns about non-stable IP hashing."""
        os.environ.pop("AUDIT_HMAC_KEY", None)
        AuditLogger(path=str(tmp_path / "audit.jsonl"))
        captured = capsys.readouterr()
        assert "audit_hmac_key_missing" in captured.out