| """Tests for audit log list, detail, filters, and post-query persistence.""" |
|
|
| import asyncio |
| from unittest.mock import AsyncMock |
| from uuid import uuid4 |
|
|
| import pytest |
| from fastapi.testclient import TestClient |
|
|
| from api.config import Settings |
| from api.main import app |
| from models.responses import SourceCitation |
| from rag.retriever import RetrievedChunk |
| from storage.audit_store import persist_query_audit |
|
|
|
|
def _seed_audit(settings: Settings, question: str = "What are key risks?", user_id: str = "analyst_001") -> str:
    """Write one audit row for the tests and return its generated query id.

    The row goes to ``settings.audit_db_path`` through ``persist_query_audit``.
    Only ``question`` and ``user_id`` vary between calls; the answer text, the
    single ``report.pdf`` citation, and the model/timing fields are fixed.
    """
    new_id = str(uuid4())
    citation = SourceCitation(
        document_name="report.pdf",
        page_number=3,
        chunk_text="Risk disclosure excerpt.",
        relevance_score=0.9,
    )
    row_fields = {
        "query_id": new_id,
        "action": "query",
        "user_id": user_id,
        "question": question,
        "collection_name": "default",
        "answer": "Grounded answer text for audit trail.",
        "sources": [citation],
        "model_used": "ollama:llama3.1:8b",
        "tokens_used": 120,
        "response_time_ms": 50,
        "kind": "ask",
    }
    # The tests calling this helper are synchronous, so the awaitable
    # returned by persist_query_audit is driven to completion here.
    asyncio.run(persist_query_audit(settings.audit_db_path, **row_fields))
    return new_id
|
|
|
|
def test_audit_logs_and_detail_success(client, settings):
    """A seeded audit row is listed by /audit/logs and readable by its id."""
    seeded_id = _seed_audit(settings)

    # List endpoint: the seeded row must be among the returned entries.
    listing = client.get("/audit/logs?limit=10&offset=0")
    assert listing.status_code == 200
    payload = listing.json()
    assert "logs" in payload
    assert payload["total"] >= 1
    listed_ids = [row["query_id"] for row in payload["logs"]]
    assert seeded_id in listed_ids

    # Detail endpoint: the stored fields come back unchanged.
    single = client.get(f"/audit/logs/{seeded_id}")
    assert single.status_code == 200
    record = single.json()
    assert record["query_id"] == seeded_id
    assert record["question"] == "What are key risks?"
    assert record["full_answer"] == "Grounded answer text for audit trail."
    cited = record["sources"]
    assert len(cited) == 1
    assert cited[0]["document_name"] == "report.pdf"
|
|
|
|
def test_audit_logs_filter_by_user_id(client, settings):
    """Filtering /audit/logs by ``user_id`` yields only that user's rows."""
    wanted_id = _seed_audit(settings, question="Q one", user_id="user_a")
    # Second row belongs to a different user and acts as noise for the filter.
    _seed_audit(settings, question="Q two", user_id="user_b")

    filters = {"user_id": "user_a", "limit": 50, "offset": 0}
    response = client.get("/audit/logs", params=filters)
    assert response.status_code == 200

    rows = response.json()["logs"]
    returned_ids = {row["query_id"] for row in rows}
    assert wanted_id in returned_ids
    assert not any(row["user_id"] != "user_a" for row in rows)
|
|
|
|
def test_audit_logs_filter_by_from_date(client, settings):
    """A freshly persisted row is absent when ``from_date`` is set to 2099.

    NOTE(review): no timestamp is passed to ``persist_query_audit``, so the
    row is presumably stamped at insert time by the store - confirm.
    """
    row_id = str(uuid4())
    row_fields = {
        "query_id": row_id,
        "action": "query",
        "user_id": "u",
        "question": "Future dated row",
        "collection_name": "default",
        "answer": "A",
        "sources": [],
        "model_used": "m",
        "tokens_used": 0,
        "response_time_ms": 1,
        "kind": "ask",
    }
    asyncio.run(persist_query_audit(settings.audit_db_path, **row_fields))

    lower_bound = {"from_date": "2099-01-01T00:00:00Z", "limit": 50, "offset": 0}
    response = client.get("/audit/logs", params=lower_bound)
    assert response.status_code == 200

    returned_ids = {row["query_id"] for row in response.json()["logs"]}
    assert row_id not in returned_ids
|
|
|
|
def test_audit_logs_filter_by_to_date(client, settings):
    """Spec: date filtering on /audit/logs (upper bound).

    A freshly persisted row must not be returned when ``to_date`` is the
    start of the year 2000.
    """
    row_id = str(uuid4())
    row_fields = {
        "query_id": row_id,
        "action": "query",
        "user_id": "u",
        "question": "Recent row",
        "collection_name": "default",
        "answer": "B",
        "sources": [],
        "model_used": "m",
        "tokens_used": 0,
        "response_time_ms": 1,
        "kind": "ask",
    }
    asyncio.run(persist_query_audit(settings.audit_db_path, **row_fields))

    upper_bound = {"to_date": "2000-01-01T00:00:00Z", "limit": 50, "offset": 0}
    response = client.get("/audit/logs", params=upper_bound)
    assert response.status_code == 200

    returned_ids = {row["query_id"] for row in response.json()["logs"]}
    assert row_id not in returned_ids
|
|
|
|
def test_ask_is_logged_after_query_ask(client, monkeypatch):
    """Spec: ask is logged after POST /query/ask."""
    stub_chunks = [
        RetrievedChunk(
            text="Audit trail test chunk.",
            score=0.9,
            source="audit-test.txt",
            page=1,
            chunk_index=0,
        )
    ]

    # Stand-ins for the retrieval/generation names the query route looks up;
    # each accepts the same argument shapes as the lambdas it replaces.
    def fake_embedding():
        return object()

    def fake_vector_store(**_kwargs):
        return object()

    def fake_retrieve(*_args):
        return stub_chunks

    def fake_answer(*_args):
        return ("Answer stored in audit.", 11)

    replacements = {
        "api.routes.query.create_embedding_function": fake_embedding,
        "api.routes.query.get_vector_store": fake_vector_store,
        "api.routes.query.retrieve_chunks": fake_retrieve,
        "api.routes.query.answer_with_grounding": fake_answer,
    }
    for target, stub in replacements.items():
        monkeypatch.setattr(target, stub)

    request_body = {
        "question": "What should appear in the audit log?",
        "collection_name": "default",
        "user_id": "audit_user",
    }
    ask_response = client.post("/query/ask", json=request_body)
    assert ask_response.status_code == 200
    logged_id = ask_response.json()["query_id"]

    # The id returned by the ask call must resolve to a stored audit record.
    detail_response = client.get(f"/audit/logs/{logged_id}")
    assert detail_response.status_code == 200
    record = detail_response.json()
    assert record["user_id"] == "audit_user"
    assert record["full_answer"] == "Answer stored in audit."
    assert record["question"] == "What should appear in the audit log?"
|
|
|
|
def test_summarise_is_logged_after_query_summarise(client, monkeypatch):
    """Spec: summarise is logged after POST /query/summarise."""
    stub_chunks = [
        RetrievedChunk(
            text="Summary source chunk.",
            score=0.85,
            source="summary.md",
            page=2,
            chunk_index=0,
        )
    ]

    # Stand-ins for the names the query route looks up; signatures mirror
    # the argument shapes each replaced callable was stubbed to accept.
    def fake_embedding():
        return object()

    def fake_vector_store(**_kwargs):
        return object()

    def fake_retrieve(*_args):
        return stub_chunks

    def fake_summary(*_args, **_kwargs):
        return ("Collection summary for audit.", 7)

    def fake_document_count(*_args):
        return 2

    replacements = {
        "api.routes.query.create_embedding_function": fake_embedding,
        "api.routes.query.get_vector_store": fake_vector_store,
        "api.routes.query.retrieve_chunks": fake_retrieve,
        "api.routes.query.summarise_with_grounding": fake_summary,
        "api.routes.query.collection_document_count": fake_document_count,
    }
    for target, stub in replacements.items():
        monkeypatch.setattr(target, stub)

    request_body = {"collection_name": "default", "focus": "key themes", "user_id": "sum_user"}
    summarise_response = client.post("/query/summarise", json=request_body)
    assert summarise_response.status_code == 200
    logged_id = summarise_response.json()["query_id"]

    # The id returned by the summarise call must resolve to a stored record.
    detail_response = client.get(f"/audit/logs/{logged_id}")
    assert detail_response.status_code == 200
    record = detail_response.json()
    assert record["full_answer"] == "Collection summary for audit."
    assert record["user_id"] == "sum_user"
|
|
|
|
def test_audit_logs_validation_error_for_bad_limit(client):
    """Requesting /audit/logs with ``limit=0`` is rejected with HTTP 422."""
    rejected = client.get("/audit/logs?limit=0&offset=0")
    assert rejected.status_code == 422
|
|
|
|
def test_audit_detail_not_found(client):
    """An unknown query id yields HTTP 404 with a "not found" detail message."""
    missing = client.get("/audit/logs/does-not-exist")
    assert missing.status_code == 404
    # Case is normalised so the wording check ignores capitalisation.
    detail_text = missing.json()["detail"].lower()
    assert "not found" in detail_text
|
|
|
|
def test_audit_logs_returns_500_on_store_failure(settings, monkeypatch):
    """A RuntimeError raised while listing audit events surfaces as HTTP 500."""
    failing_list = AsyncMock(side_effect=RuntimeError("audit store failure"))

    # Both modules are pointed at the fixture settings before the app starts.
    for target in ("api.main.get_settings", "api.routes.audit.get_settings"):
        monkeypatch.setattr(target, lambda: settings)
    monkeypatch.setattr("api.routes.audit.list_audit_events", failing_list)

    # raise_server_exceptions=False lets the client hand back the error
    # response instead of re-raising the server-side exception in the test.
    with TestClient(app, raise_server_exceptions=False) as http:
        outcome = http.get("/audit/logs")

    assert outcome.status_code == 500
|
|