| from unittest.mock import AsyncMock |
|
|
| import pytest |
| from fastapi.testclient import TestClient |
|
|
| from api.config import Settings |
| from api.main import app |
| from rag.retriever import RetrievedChunk |
|
|
|
|
def _test_settings(tmp_path) -> Settings:
    """Return test ``Settings`` whose on-disk stores all live under *tmp_path*.

    The Chroma persistence directory and the audit/jobs database files are
    placed inside the per-test temporary directory. The LLM provider is pinned
    to ``"ollama"`` and the default retrieval depth to 3.
    """
    overrides = {
        "llm_provider": "ollama",
        "chroma_persist_directory": str(tmp_path / "chroma"),
        "audit_db_path": str(tmp_path / "audit.db"),
        "jobs_db_path": str(tmp_path / "jobs.db"),
        "top_k_results": 3,
    }
    return Settings(**overrides)
|
|
|
|
@pytest.fixture
def client(tmp_path, monkeypatch):
    """Yield a ``TestClient`` for ``app`` backed by temporary-directory settings.

    ``get_settings`` is patched in both ``api.main`` and ``api.routes.query``
    so every settings lookup in those modules returns the same isolated
    ``Settings`` instance. The client is used as a context manager so the
    app's startup/shutdown (lifespan) handlers run around the test.
    """
    isolated_settings = _test_settings(tmp_path)
    for target in ("api.main.get_settings", "api.routes.query.get_settings"):
        monkeypatch.setattr(target, lambda: isolated_settings)
    with TestClient(app) as test_client:
        yield test_client
|
|
|
|
def test_ask_returns_grounded_answer_with_sources(client, monkeypatch):
    """``/query/ask`` returns the stubbed answer together with its source metadata.

    The embedding function, vector store, retrieval, grounded answer and audit
    persistence referenced from ``api.routes.query`` are all stubbed. The
    assertions check how the route shapes the response: answer text, echoed
    question, one source entry, token count and timing/model fields.
    """
    retrieved = [
        RetrievedChunk(
            text="Audi has strategic EV expansion plans.",
            score=0.92,
            source="strategy.md",
            page=1,
            chunk_index=0,
        )
    ]
    stubs = {
        "api.routes.query.create_embedding_function": lambda: object(),
        "api.routes.query.get_vector_store": lambda **_: object(),
        "api.routes.query.retrieve_chunks": lambda *_: retrieved,
        "api.routes.query.answer_with_grounding": lambda *_: ("Audi is expanding EV investment.", 42),
        "api.routes.query.persist_query_audit": AsyncMock(return_value="evt-1"),
    }
    for target, replacement in stubs.items():
        monkeypatch.setattr(target, replacement)

    payload = {
        "question": "What is Audi doing in EV markets worldwide?",
        "collection_name": "default",
        "top_k": 3,
        "user_id": "tester",
    }
    response = client.post("/query/ask", json=payload)

    assert response.status_code == 200
    body = response.json()
    assert body["answer"] == "Audi is expanding EV investment."
    assert body["question"].startswith("What is Audi")
    assert body["tokens_used"] == 42
    for required_key in ("query_id", "response_time_ms", "model_used"):
        assert required_key in body

    assert len(body["sources"]) == 1
    first_source = body["sources"][0]
    assert first_source["document_name"] == "strategy.md"
    assert first_source["page_number"] == 1
|
|
|
|
def test_ask_respects_top_k_in_retrieve_call(client, monkeypatch):
    """The request's ``top_k`` reaches ``retrieve_chunks`` as its ``k`` argument."""
    observed_k = None

    def fake_retrieve(vs, question, k):
        # Record the retrieval depth the route asked for; return no chunks.
        nonlocal observed_k
        observed_k = k
        return []

    replacements = (
        ("api.routes.query.create_embedding_function", lambda: object()),
        ("api.routes.query.get_vector_store", lambda **_: object()),
        ("api.routes.query.retrieve_chunks", fake_retrieve),
        ("api.routes.query.answer_with_grounding", lambda *_: ("No match answer", 0)),
        ("api.routes.query.persist_query_audit", AsyncMock()),
    )
    for target, replacement in replacements:
        monkeypatch.setattr(target, replacement)

    request_body = {
        "question": "What is known about the topic here?",
        "collection_name": "default",
        "top_k": 7,
    }
    response = client.post("/query/ask", json=request_body)

    assert response.status_code == 200
    assert observed_k == 7
|
|
|
|
def test_ask_returns_422_for_invalid_payload(client):
    """A payload without ``question`` is rejected by validation with HTTP 422."""
    payload_missing_question = {"collection_name": "default"}
    outcome = client.post("/query/ask", json=payload_missing_question)
    assert outcome.status_code == 422
|
|
|
|
def test_ask_returns_422_for_short_question(client):
    """A too-short question (``"hi"``) is rejected by validation with HTTP 422."""
    too_short = {"question": "hi", "collection_name": "default"}
    outcome = client.post("/query/ask", json=too_short)
    assert outcome.status_code == 422
|
|
|
|
def test_ask_returns_500_when_retrieval_fails(client, monkeypatch):
    """A retrieval exception surfaces as HTTP 500 with the error text in ``detail``."""

    def failing_retrieve(*_args, **_kwargs):
        # A lambda body cannot contain ``raise``, so use a small named function
        # instead of the generator ``.throw()`` trick. Accepting ``**kwargs``
        # keeps the stub working however the route passes its arguments.
        raise RuntimeError("retrieval failed")

    monkeypatch.setattr("api.routes.query.create_embedding_function", lambda: object())
    monkeypatch.setattr("api.routes.query.get_vector_store", lambda **_: object())
    monkeypatch.setattr("api.routes.query.retrieve_chunks", failing_retrieve)
    # Keep the test hermetic, as the sibling tests do: if the route records an
    # audit event on its failure path, it must not touch real storage.
    monkeypatch.setattr("api.routes.query.persist_query_audit", AsyncMock())

    response = client.post(
        "/query/ask",
        json={"question": "What happened in the documents?", "collection_name": "default"},
    )

    assert response.status_code == 500
    assert "retrieval failed" in response.json()["detail"]
|
|
|
|
def test_summarise_returns_500_when_audit_persist_fails(client, monkeypatch):
    """A failing audit write on ``/query/summarise`` yields HTTP 500 with its message.

    Retrieval and summarisation are stubbed to succeed, so the awaited
    ``persist_query_audit`` stub (which raises) is the only failure source.
    """
    report_chunks = [
        RetrievedChunk(
            text="Revenue and risks are discussed in the report.",
            score=0.88,
            source="report.txt",
            page=None,
            chunk_index=2,
        )
    ]
    failing_audit = AsyncMock(side_effect=RuntimeError("audit write failed"))
    stubs = {
        "api.routes.query.create_embedding_function": lambda: object(),
        "api.routes.query.get_vector_store": lambda **_: object(),
        "api.routes.query.retrieve_chunks": lambda *_: report_chunks,
        "api.routes.query.summarise_with_grounding": lambda *_, **__: ("Summary output", 10),
        "api.routes.query.collection_document_count": lambda *_: 5,
        "api.routes.query.persist_query_audit": failing_audit,
    }
    for target, replacement in stubs.items():
        monkeypatch.setattr(target, replacement)

    summarise_request = {"collection_name": "default", "focus": "summarise risks", "user_id": "u1"}
    response = client.post("/query/summarise", json=summarise_request)

    assert response.status_code == 500
    assert "audit write failed" in response.json()["detail"]
|
|