# Document-Audit-RAG / tests/test_query.py
# Mayank Chugh
# Enhance environment configuration and API documentation for Milestone 11
# a32f9e3
from unittest.mock import AsyncMock
import pytest
from fastapi.testclient import TestClient
from api.config import Settings
from api.main import app
from rag.retriever import RetrievedChunk
def _test_settings(tmp_path) -> Settings:
    """Build an isolated ``Settings`` object for the query API tests.

    Every on-disk location (Chroma persistence directory, audit DB, jobs DB)
    is placed under the per-test ``tmp_path``. The LLM provider is pinned to
    ``"ollama"`` and ``top_k_results`` to 3.
    """
    base = tmp_path
    return Settings(
        top_k_results=3,
        llm_provider="ollama",
        jobs_db_path=str(base / "jobs.db"),
        audit_db_path=str(base / "audit.db"),
        chroma_persist_directory=str(base / "chroma"),
    )
@pytest.fixture
def client(tmp_path, monkeypatch):
    """Yield a ``TestClient`` for ``app`` that uses the isolated test settings.

    ``get_settings`` is patched in both ``api.main`` and
    ``api.routes.query``, so both modules resolve the same tmp-path-backed
    ``Settings`` instance. The client is used as a context manager, which
    runs the app's lifespan (startup/shutdown) around the test.
    """
    isolated = _test_settings(tmp_path)

    def _settings_override():
        return isolated

    for target in ("api.main.get_settings", "api.routes.query.get_settings"):
        monkeypatch.setattr(target, _settings_override)

    with TestClient(app) as http:
        yield http
def test_ask_returns_grounded_answer_with_sources(client, monkeypatch):
    """``/query/ask`` returns the stubbed answer along with its source metadata.

    The embedding function, vector store, retrieval, answer generation and
    audit persistence are all stubbed. The test therefore only checks how
    the route shapes its JSON response from those collaborators' outputs.
    """
    retrieved = [
        RetrievedChunk(
            text="Audi has strategic EV expansion plans.",
            score=0.92,
            source="strategy.md",
            page=1,
            chunk_index=0,
        )
    ]
    stubs = {
        "api.routes.query.create_embedding_function": lambda: object(),
        "api.routes.query.get_vector_store": lambda **_: object(),
        "api.routes.query.retrieve_chunks": lambda *_: retrieved,
        "api.routes.query.answer_with_grounding": lambda *_: ("Audi is expanding EV investment.", 42),
        "api.routes.query.persist_query_audit": AsyncMock(return_value="evt-1"),
    }
    for target, replacement in stubs.items():
        monkeypatch.setattr(target, replacement)

    payload = {
        "question": "What is Audi doing in EV markets worldwide?",
        "collection_name": "default",
        "top_k": 3,
        "user_id": "tester",
    }
    response = client.post("/query/ask", json=payload)

    assert response.status_code == 200
    data = response.json()
    # Answer text and token count come straight from the stubbed generator.
    assert data["answer"] == "Audi is expanding EV investment."
    assert data["tokens_used"] == 42
    assert data["question"].startswith("What is Audi")
    # These fields are produced by the route itself; only their presence is checked.
    for key in ("query_id", "response_time_ms", "model_used"):
        assert key in data
    # The single retrieved chunk should surface as exactly one source entry.
    assert len(data["sources"]) == 1
    source = data["sources"][0]
    assert source["document_name"] == "strategy.md"
    assert source["page_number"] == 1
def test_ask_respects_top_k_in_retrieve_call(client, monkeypatch):
    """The request's ``top_k`` is passed through to ``retrieve_chunks`` as ``k``."""
    seen_k: list[object] = []

    def fake_retrieve(vs, question, k):
        # Record the requested result count and return no chunks.
        seen_k.append(k)
        return []

    monkeypatch.setattr("api.routes.query.create_embedding_function", lambda: object())
    monkeypatch.setattr("api.routes.query.get_vector_store", lambda **_: object())
    monkeypatch.setattr("api.routes.query.retrieve_chunks", fake_retrieve)
    monkeypatch.setattr("api.routes.query.answer_with_grounding", lambda *_: ("No match answer", 0))
    monkeypatch.setattr("api.routes.query.persist_query_audit", AsyncMock())

    request_body = {
        "question": "What is known about the topic here?",
        "collection_name": "default",
        "top_k": 7,
    }
    response = client.post("/query/ask", json=request_body)

    assert response.status_code == 200
    # The most recent retrieval call must have used the requested k.
    assert seen_k and seen_k[-1] == 7
def test_ask_returns_422_for_invalid_payload(client):
    """A ``/query/ask`` body that omits ``question`` is rejected with HTTP 422."""
    missing_question = {"collection_name": "default"}
    assert client.post("/query/ask", json=missing_question).status_code == 422
def test_ask_returns_422_for_short_question(client):
    """A very short ``question`` ("hi") is rejected by request validation with HTTP 422."""
    too_short = {"question": "hi", "collection_name": "default"}
    resp = client.post("/query/ask", json=too_short)
    assert resp.status_code == 422
def test_ask_returns_500_when_retrieval_fails(client, monkeypatch):
    """A retrieval exception surfaces as HTTP 500 whose detail contains the error text."""

    def failing_retrieve(*_args, **_kwargs):
        # Accept any call signature, so the stub fails with the intended
        # RuntimeError rather than a TypeError if the route passes keywords.
        raise RuntimeError("retrieval failed")

    monkeypatch.setattr("api.routes.query.create_embedding_function", lambda: object())
    monkeypatch.setattr("api.routes.query.get_vector_store", lambda **_: object())
    monkeypatch.setattr("api.routes.query.retrieve_chunks", failing_retrieve)

    response = client.post(
        "/query/ask",
        json={"question": "What happened in the documents?", "collection_name": "default"},
    )

    assert response.status_code == 500
    assert "retrieval failed" in response.json()["detail"]
def test_summarise_returns_500_when_audit_persist_fails(client, monkeypatch):
    """``/query/summarise`` reports HTTP 500 when writing the audit record fails.

    Retrieval, summarisation and the document count all succeed via stubs.
    Only ``persist_query_audit`` raises, so the test expects the 500 and its
    detail message to originate from the audit step.
    """
    report_chunk = RetrievedChunk(
        text="Revenue and risks are discussed in the report.",
        score=0.88,
        source="report.txt",
        page=None,
        chunk_index=2,
    )
    chunks = [report_chunk]
    failing_audit = AsyncMock(side_effect=RuntimeError("audit write failed"))

    stubs = {
        "api.routes.query.create_embedding_function": lambda: object(),
        "api.routes.query.get_vector_store": lambda **_: object(),
        "api.routes.query.retrieve_chunks": lambda *_: chunks,
        "api.routes.query.summarise_with_grounding": lambda *_, **__: ("Summary output", 10),
        "api.routes.query.collection_document_count": lambda *_: 5,
        "api.routes.query.persist_query_audit": failing_audit,
    }
    for target, replacement in stubs.items():
        monkeypatch.setattr(target, replacement)

    body = {"collection_name": "default", "focus": "summarise risks", "user_id": "u1"}
    response = client.post("/query/summarise", json=body)

    assert response.status_code == 500
    assert "audit write failed" in response.json()["detail"]