File size: 7,491 Bytes
d44b33d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
"""Tests for audit log list, detail, filters, and post-query persistence."""

import asyncio
from unittest.mock import AsyncMock
from uuid import uuid4

import pytest
from fastapi.testclient import TestClient

from api.config import Settings
from api.main import app
from models.responses import SourceCitation
from rag.retriever import RetrievedChunk
from storage.audit_store import persist_query_audit


def _seed_audit(settings: Settings, question: str = "What are key risks?", user_id: str = "analyst_001") -> str:
    """Write one "ask" audit row and return its freshly generated query id.

    The row goes through ``persist_query_audit`` into
    ``settings.audit_db_path``. Only ``question`` and ``user_id`` vary per
    call; the answer text, the single ``report.pdf`` citation, and the
    model/usage figures are fixed.
    """
    seeded_id = str(uuid4())
    citation = SourceCitation(
        document_name="report.pdf",
        page_number=3,
        chunk_text="Risk disclosure excerpt.",
        relevance_score=0.9,
    )
    pending_write = persist_query_audit(
        settings.audit_db_path,
        query_id=seeded_id,
        action="query",
        user_id=user_id,
        question=question,
        collection_name="default",
        answer="Grounded answer text for audit trail.",
        sources=[citation],
        model_used="ollama:llama3.1:8b",
        tokens_used=120,
        response_time_ms=50,
        kind="ask",
    )
    # Drive the async store call to completion so synchronous tests can use this helper.
    asyncio.run(pending_write)
    return seeded_id


def test_audit_logs_and_detail_success(client, settings):
    """A seeded audit row appears in the list endpoint and is returned in full by the detail endpoint."""
    seeded_id = _seed_audit(settings)

    # List view: the seeded row must be present on the first page.
    listing = client.get("/audit/logs?limit=10&offset=0")
    assert listing.status_code == 200
    page = listing.json()
    assert "logs" in page
    assert page["total"] >= 1
    assert any(row["query_id"] == seeded_id for row in page["logs"])

    # Detail view: fields must match the defaults written by _seed_audit.
    lookup = client.get(f"/audit/logs/{seeded_id}")
    assert lookup.status_code == 200
    record = lookup.json()
    assert record["query_id"] == seeded_id
    assert record["question"] == "What are key risks?"
    assert record["full_answer"] == "Grounded answer text for audit trail."
    cited = record["sources"]
    assert len(cited) == 1
    assert cited[0]["document_name"] == "report.pdf"


def test_audit_logs_filter_by_user_id(client, settings):
    """Filtering /audit/logs by ``user_id`` returns the matching user's rows and nobody else's."""
    wanted_id = _seed_audit(settings, question="Q one", user_id="user_a")
    # Second row belongs to a different user and must be filtered out.
    _seed_audit(settings, question="Q two", user_id="user_b")

    filters = {"user_id": "user_a", "limit": 50, "offset": 0}
    response = client.get("/audit/logs", params=filters)
    assert response.status_code == 200

    rows = response.json()["logs"]
    returned_ids = {row["query_id"] for row in rows}
    assert wanted_id in returned_ids
    for row in rows:
        assert row["user_id"] == "user_a"


def test_audit_logs_filter_by_from_date(client, settings):
    """Spec: date filtering on /audit/logs (lower bound).

    Seeds one row, then checks both sides of the ``from_date`` filter:

    * a lower bound in the past (year 2000) must still return the row;
    * a lower bound in the far future (year 2099) must exclude it.

    The first check is a positive control. Without it the exclusion
    assertion would also pass if the endpoint returned no rows at all or
    the seed write never landed.

    No timestamp is passed to ``persist_query_audit``, so the row is
    presumably stamped with the current time by the store (verify against
    ``storage.audit_store``).
    """
    query_id = str(uuid4())
    asyncio.run(
        persist_query_audit(
            settings.audit_db_path,
            query_id=query_id,
            action="query",
            user_id="u",
            question="Future dated row",
            collection_name="default",
            answer="A",
            sources=[],
            model_used="m",
            tokens_used=0,
            response_time_ms=1,
            kind="ask",
        )
    )

    # Positive control: the row is visible when the lower bound precedes it.
    included = client.get("/audit/logs", params={"from_date": "2000-01-01T00:00:00Z", "limit": 50, "offset": 0})
    assert included.status_code == 200
    assert query_id in {e["query_id"] for e in included.json()["logs"]}

    # Actual filter check: a lower bound after the row's creation hides it.
    r = client.get("/audit/logs", params={"from_date": "2099-01-01T00:00:00Z", "limit": 50, "offset": 0})
    assert r.status_code == 200
    body = r.json()
    assert query_id not in {e["query_id"] for e in body["logs"]}


def test_audit_logs_filter_by_to_date(client, settings):
    """Spec: date filtering on /audit/logs (upper bound).

    Seeds one row, then checks both sides of the ``to_date`` filter:

    * an upper bound in the far future (year 2099) must still return the row;
    * an upper bound in the past (year 2000) must exclude it.

    The first check is a positive control. Without it the exclusion
    assertion would also pass if the endpoint returned no rows at all or
    the seed write never landed.

    No timestamp is passed to ``persist_query_audit``, so the row is
    presumably stamped with the current time by the store (verify against
    ``storage.audit_store``).
    """
    query_id = str(uuid4())
    asyncio.run(
        persist_query_audit(
            settings.audit_db_path,
            query_id=query_id,
            action="query",
            user_id="u",
            question="Recent row",
            collection_name="default",
            answer="B",
            sources=[],
            model_used="m",
            tokens_used=0,
            response_time_ms=1,
            kind="ask",
        )
    )

    # Positive control: the row is visible when the upper bound follows it.
    included = client.get("/audit/logs", params={"to_date": "2099-01-01T00:00:00Z", "limit": 50, "offset": 0})
    assert included.status_code == 200
    assert query_id in {e["query_id"] for e in included.json()["logs"]}

    # Actual filter check: an upper bound before the row's creation hides it.
    r = client.get("/audit/logs", params={"to_date": "2000-01-01T00:00:00Z", "limit": 50, "offset": 0})
    assert r.status_code == 200
    body = r.json()
    assert query_id not in {e["query_id"] for e in body["logs"]}


def test_ask_is_logged_after_query_ask(client, monkeypatch):
    """Spec: ask is logged after POST /query/ask.

    Embedding, vector store, retrieval and answer generation in
    ``api.routes.query`` are replaced with stubs, so only the route's audit
    persistence is exercised.
    """
    retrieved = [
        RetrievedChunk(
            text="Audit trail test chunk.",
            score=0.9,
            source="audit-test.txt",
            page=1,
            chunk_index=0,
        )
    ]
    stubs = (
        ("api.routes.query.create_embedding_function", lambda: object()),
        ("api.routes.query.get_vector_store", lambda **_: object()),
        ("api.routes.query.retrieve_chunks", lambda *_: retrieved),
        ("api.routes.query.answer_with_grounding", lambda *_: ("Answer stored in audit.", 11)),
    )
    for target, replacement in stubs:
        monkeypatch.setattr(target, replacement)

    payload = {
        "question": "What should appear in the audit log?",
        "collection_name": "default",
        "user_id": "audit_user",
    }
    ask_response = client.post("/query/ask", json=payload)
    assert ask_response.status_code == 200
    logged_id = ask_response.json()["query_id"]

    # The id returned by the ask route must resolve to a stored audit record.
    audit_response = client.get(f"/audit/logs/{logged_id}")
    assert audit_response.status_code == 200
    record = audit_response.json()
    assert record["user_id"] == "audit_user"
    assert record["full_answer"] == "Answer stored in audit."
    assert record["question"] == "What should appear in the audit log?"


def test_summarise_is_logged_after_query_summarise(client, monkeypatch):
    """Spec: summarise is logged after POST /query/summarise.

    Embedding, vector store, retrieval, summarisation and the document
    count in ``api.routes.query`` are replaced with stubs, so only the
    route's audit persistence is exercised.
    """
    retrieved = [
        RetrievedChunk(
            text="Summary source chunk.",
            score=0.85,
            source="summary.md",
            page=2,
            chunk_index=0,
        )
    ]
    stubs = (
        ("api.routes.query.create_embedding_function", lambda: object()),
        ("api.routes.query.get_vector_store", lambda **_: object()),
        ("api.routes.query.retrieve_chunks", lambda *_: retrieved),
        ("api.routes.query.summarise_with_grounding", lambda *_, **__: ("Collection summary for audit.", 7)),
        ("api.routes.query.collection_document_count", lambda *_: 2),
    )
    for target, replacement in stubs:
        monkeypatch.setattr(target, replacement)

    payload = {"collection_name": "default", "focus": "key themes", "user_id": "sum_user"}
    summary_response = client.post("/query/summarise", json=payload)
    assert summary_response.status_code == 200
    logged_id = summary_response.json()["query_id"]

    # The id returned by the summarise route must resolve to a stored audit record.
    audit_response = client.get(f"/audit/logs/{logged_id}")
    assert audit_response.status_code == 200
    record = audit_response.json()
    assert record["full_answer"] == "Collection summary for audit."
    assert record["user_id"] == "sum_user"


def test_audit_logs_validation_error_for_bad_limit(client):
    """A ``limit`` of 0 on /audit/logs is rejected with HTTP 422."""
    rejected = client.get("/audit/logs?limit=0&offset=0")
    status = rejected.status_code
    assert status == 422


def test_audit_detail_not_found(client):
    """An unknown query id on the detail route yields 404 with a "not found" detail message."""
    missing = client.get("/audit/logs/does-not-exist")
    assert missing.status_code == 404
    detail_text = missing.json()["detail"]
    assert "not found" in detail_text.lower()


def test_audit_logs_returns_500_on_store_failure(settings, monkeypatch):
    """When ``list_audit_events`` raises, GET /audit/logs responds with HTTP 500."""

    def _fixed_settings():
        return settings

    failing_store = AsyncMock(side_effect=RuntimeError("audit store failure"))

    for target in ("api.main.get_settings", "api.routes.audit.get_settings"):
        monkeypatch.setattr(target, _fixed_settings)
    monkeypatch.setattr("api.routes.audit.list_audit_events", failing_store)

    # raise_server_exceptions=False makes the client return the error response
    # instead of re-raising the RuntimeError inside the test.
    with TestClient(app, raise_server_exceptions=False) as failing_client:
        status = failing_client.get("/audit/logs").status_code

    assert status == 500