"""API tests for the ``/query/ask`` and ``/query/summarise`` routes.

The tests monkeypatch the embedding/vector-store factories, retrieval, LLM
answering/summarising and audit persistence in ``api.routes.query``. What
they exercise is request validation, response shaping and error mapping of
the routes themselves.
"""

from collections.abc import Iterator
from pathlib import Path
from typing import NoReturn
from unittest.mock import AsyncMock

import pytest
from fastapi.testclient import TestClient

from api.config import Settings
from api.main import app
from rag.retriever import RetrievedChunk


def _test_settings(tmp_path: Path) -> Settings:
    """Build settings whose Chroma, audit and jobs stores live under *tmp_path*."""
    return Settings(
        llm_provider="ollama",
        chroma_persist_directory=str(tmp_path / "chroma"),
        audit_db_path=str(tmp_path / "audit.db"),
        jobs_db_path=str(tmp_path / "jobs.db"),
        top_k_results=3,
    )


def _stub_vector_store(monkeypatch: pytest.MonkeyPatch) -> None:
    """Replace the embedding function and vector-store factory with inert placeholders."""
    monkeypatch.setattr("api.routes.query.create_embedding_function", lambda: object())
    monkeypatch.setattr("api.routes.query.get_vector_store", lambda **_: object())


def _failing_retrieve(*_: object) -> NoReturn:
    """Stand-in for ``retrieve_chunks`` that always raises."""
    raise RuntimeError("retrieval failed")


@pytest.fixture
def client(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> Iterator[TestClient]:
    """Yield a TestClient for ``app`` wired to the temporary settings."""
    settings = _test_settings(tmp_path)
    # Patched in both modules, presumably because each imports get_settings by
    # name and so holds its own reference.
    monkeypatch.setattr("api.main.get_settings", lambda: settings)
    monkeypatch.setattr("api.routes.query.get_settings", lambda: settings)
    # Using TestClient as a context manager runs the app's startup/shutdown
    # handlers around the test.
    with TestClient(app) as test_client:
        yield test_client


def test_ask_returns_grounded_answer_with_sources(client, monkeypatch):
    """A successful ask returns the LLM answer plus source metadata and token usage."""
    chunks = [
        RetrievedChunk(
            text="Audi has strategic EV expansion plans.",
            score=0.92,
            source="strategy.md",
            page=1,
            chunk_index=0,
        )
    ]
    _stub_vector_store(monkeypatch)
    monkeypatch.setattr("api.routes.query.retrieve_chunks", lambda *_: chunks)
    monkeypatch.setattr(
        "api.routes.query.answer_with_grounding",
        lambda *_: ("Audi is expanding EV investment.", 42),
    )
    monkeypatch.setattr(
        "api.routes.query.persist_query_audit", AsyncMock(return_value="evt-1")
    )

    response = client.post(
        "/query/ask",
        json={
            "question": "What is Audi doing in EV markets worldwide?",
            "collection_name": "default",
            "top_k": 3,
            "user_id": "tester",
        },
    )

    assert response.status_code == 200
    body = response.json()
    assert body["answer"] == "Audi is expanding EV investment."
    assert "query_id" in body
    assert body["question"].startswith("What is Audi")
    assert len(body["sources"]) == 1
    assert body["sources"][0]["document_name"] == "strategy.md"
    assert body["sources"][0]["page_number"] == 1
    assert body["tokens_used"] == 42
    assert "response_time_ms" in body
    assert "model_used" in body


def test_ask_respects_top_k_in_retrieve_call(client, monkeypatch):
    """The request's ``top_k`` is forwarded to ``retrieve_chunks`` as its third argument."""
    captured: dict[str, object] = {}

    def capture_retrieve(vs, question, k):
        captured["k"] = k
        return []

    _stub_vector_store(monkeypatch)
    monkeypatch.setattr("api.routes.query.retrieve_chunks", capture_retrieve)
    monkeypatch.setattr(
        "api.routes.query.answer_with_grounding", lambda *_: ("No match answer", 0)
    )
    monkeypatch.setattr("api.routes.query.persist_query_audit", AsyncMock())

    response = client.post(
        "/query/ask",
        json={
            "question": "What is known about the topic here?",
            "collection_name": "default",
            "top_k": 7,
        },
    )

    assert response.status_code == 200
    assert captured.get("k") == 7


def test_ask_returns_422_for_invalid_payload(client):
    """A payload without ``question`` is rejected by request validation."""
    response = client.post("/query/ask", json={"collection_name": "default"})
    assert response.status_code == 422


def test_ask_returns_422_for_short_question(client):
    """A too-short ``question`` is rejected by request validation."""
    response = client.post(
        "/query/ask",
        json={"question": "hi", "collection_name": "default"},
    )
    assert response.status_code == 422


def test_ask_returns_500_when_retrieval_fails(client, monkeypatch):
    """A retrieval exception surfaces as a 500 whose detail carries the error text."""
    _stub_vector_store(monkeypatch)
    monkeypatch.setattr("api.routes.query.retrieve_chunks", _failing_retrieve)

    response = client.post(
        "/query/ask",
        json={
            "question": "What happened in the documents?",
            "collection_name": "default",
        },
    )

    assert response.status_code == 500
    assert "retrieval failed" in response.json()["detail"]


def test_summarise_returns_500_when_audit_persist_fails(client, monkeypatch):
    """An audit-persistence exception on summarise surfaces as a 500 with its message."""
    chunks = [
        RetrievedChunk(
            text="Revenue and risks are discussed in the report.",
            score=0.88,
            source="report.txt",
            page=None,
            chunk_index=2,
        )
    ]
    _stub_vector_store(monkeypatch)
    monkeypatch.setattr("api.routes.query.retrieve_chunks", lambda *_: chunks)
    monkeypatch.setattr(
        "api.routes.query.summarise_with_grounding",
        lambda *_, **__: ("Summary output", 10),
    )
    monkeypatch.setattr("api.routes.query.collection_document_count", lambda *_: 5)
    monkeypatch.setattr(
        "api.routes.query.persist_query_audit",
        AsyncMock(side_effect=RuntimeError("audit write failed")),
    )

    response = client.post(
        "/query/summarise",
        json={"collection_name": "default", "focus": "summarise risks", "user_id": "u1"},
    )

    assert response.status_code == 500
    assert "audit write failed" in response.json()["detail"]