File size: 5,280 Bytes
fceb91f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a32f9e3
fceb91f
 
 
 
a32f9e3
 
 
 
 
 
fceb91f
 
 
 
 
a32f9e3
 
fceb91f
a32f9e3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fceb91f
 
 
 
 
 
 
a32f9e3
 
 
 
 
 
 
 
fceb91f
 
 
 
 
 
 
a32f9e3
fceb91f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a32f9e3
 
fceb91f
 
 
 
 
 
 
a32f9e3
fceb91f
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
from unittest.mock import AsyncMock

import pytest
from fastapi.testclient import TestClient

from api.config import Settings
from api.main import app
from rag.retriever import RetrievedChunk


def _test_settings(tmp_path) -> Settings:
    """Build a ``Settings`` instance whose storage paths live under *tmp_path*.

    The Chroma directory, audit database and jobs database are each given a
    path inside *tmp_path*; the LLM provider is fixed to ``"ollama"`` and
    ``top_k_results`` to ``3``.
    """
    storage_paths = {
        "chroma_persist_directory": str(tmp_path / "chroma"),
        "audit_db_path": str(tmp_path / "audit.db"),
        "jobs_db_path": str(tmp_path / "jobs.db"),
    }
    return Settings(llm_provider="ollama", top_k_results=3, **storage_paths)


@pytest.fixture
def client(tmp_path, monkeypatch):
    """Yield a ``TestClient`` for ``app`` with ``get_settings`` patched.

    ``get_settings`` is replaced in both ``api.main`` and ``api.routes.query``
    by a callable returning the settings built from *tmp_path*. The client is
    used as a context manager, so it is closed once the test finishes.
    """
    settings = _test_settings(tmp_path)

    def _settings_override():
        return settings

    for target in ("api.main.get_settings", "api.routes.query.get_settings"):
        monkeypatch.setattr(target, _settings_override)

    with TestClient(app) as test_client:
        yield test_client


def test_ask_returns_grounded_answer_with_sources(client, monkeypatch):
    """``/query/ask`` returns the stubbed answer, one source and the token count.

    Embedding creation, vector-store lookup, retrieval, answer generation and
    audit persistence in ``api.routes.query`` are all replaced with stubs, so
    the assertions cover how the route assembles its response from them.
    """
    chunks = [
        RetrievedChunk(
            text="Audi has strategic EV expansion plans.",
            score=0.92,
            source="strategy.md",
            page=1,
            chunk_index=0,
        )
    ]
    patches = [
        ("api.routes.query.create_embedding_function", lambda: object()),
        ("api.routes.query.get_vector_store", lambda **_: object()),
        ("api.routes.query.retrieve_chunks", lambda *_: chunks),
        ("api.routes.query.answer_with_grounding", lambda *_: ("Audi is expanding EV investment.", 42)),
        ("api.routes.query.persist_query_audit", AsyncMock(return_value="evt-1")),
    ]
    for target, stub in patches:
        monkeypatch.setattr(target, stub)

    payload = {
        "question": "What is Audi doing in EV markets worldwide?",
        "collection_name": "default",
        "top_k": 3,
        "user_id": "tester",
    }
    response = client.post("/query/ask", json=payload)

    assert response.status_code == 200
    body = response.json()

    # Values that must be passed through from the stubs unchanged.
    assert body["answer"] == "Audi is expanding EV investment."
    assert body["tokens_used"] == 42
    assert body["question"].startswith("What is Audi")

    # Fields whose presence is checked but whose value is not pinned.
    for field in ("query_id", "response_time_ms", "model_used"):
        assert field in body

    sources = body["sources"]
    assert len(sources) == 1
    first_source = sources[0]
    assert first_source["document_name"] == "strategy.md"
    assert first_source["page_number"] == 1


def test_ask_respects_top_k_in_retrieve_call(client, monkeypatch):
    """The ``top_k`` sent to ``/query/ask`` is the ``k`` given to ``retrieve_chunks``."""
    requested_k: object = None  # stays None if the retrieval stub is never called

    def capture_retrieve(vs, question, k):
        nonlocal requested_k
        requested_k = k
        return []

    patches = [
        ("api.routes.query.create_embedding_function", lambda: object()),
        ("api.routes.query.get_vector_store", lambda **_: object()),
        ("api.routes.query.retrieve_chunks", capture_retrieve),
        ("api.routes.query.answer_with_grounding", lambda *_: ("No match answer", 0)),
        ("api.routes.query.persist_query_audit", AsyncMock()),
    ]
    for target, stub in patches:
        monkeypatch.setattr(target, stub)

    payload = {"question": "What is known about the topic here?", "collection_name": "default", "top_k": 7}
    response = client.post("/query/ask", json=payload)

    assert response.status_code == 200
    assert requested_k == 7


def test_ask_returns_422_for_invalid_payload(client):
    """Posting a body that contains only ``collection_name`` yields HTTP 422."""
    body_without_question = {"collection_name": "default"}
    response = client.post("/query/ask", json=body_without_question)
    assert response.status_code == 422


def test_ask_returns_422_for_short_question(client):
    """Posting the two-character question ``"hi"`` yields HTTP 422."""
    payload = {"question": "hi", "collection_name": "default"}
    response = client.post("/query/ask", json=payload)
    assert response.status_code == 422


def test_ask_returns_500_when_retrieval_fails(client, monkeypatch):
    """A ``RuntimeError`` from ``retrieve_chunks`` surfaces as HTTP 500 with its message."""

    def failing_retrieve(*_):
        # Stand-in for retrieve_chunks that always fails with a known message.
        raise RuntimeError("retrieval failed")

    monkeypatch.setattr("api.routes.query.create_embedding_function", lambda: object())
    monkeypatch.setattr("api.routes.query.get_vector_store", lambda **_: object())
    monkeypatch.setattr("api.routes.query.retrieve_chunks", failing_retrieve)

    payload = {"question": "What happened in the documents?", "collection_name": "default"}
    response = client.post("/query/ask", json=payload)

    assert response.status_code == 500
    error_detail = response.json()["detail"]
    assert "retrieval failed" in error_detail


def test_summarise_returns_500_when_audit_persist_fails(client, monkeypatch):
    """``/query/summarise`` answers HTTP 500 when ``persist_query_audit`` raises.

    Retrieval, summarisation and the document count are stubbed to succeed;
    only the audit write is made to fail, and its message must appear in the
    response ``detail``.
    """
    chunks = [
        RetrievedChunk(
            text="Revenue and risks are discussed in the report.",
            score=0.88,
            source="report.txt",
            page=None,
            chunk_index=2,
        )
    ]
    failing_audit = AsyncMock(side_effect=RuntimeError("audit write failed"))
    patches = [
        ("api.routes.query.create_embedding_function", lambda: object()),
        ("api.routes.query.get_vector_store", lambda **_: object()),
        ("api.routes.query.retrieve_chunks", lambda *_: chunks),
        ("api.routes.query.summarise_with_grounding", lambda *_, **__: ("Summary output", 10)),
        ("api.routes.query.collection_document_count", lambda *_: 5),
        ("api.routes.query.persist_query_audit", failing_audit),
    ]
    for target, stub in patches:
        monkeypatch.setattr(target, stub)

    payload = {"collection_name": "default", "focus": "summarise risks", "user_id": "u1"}
    response = client.post("/query/summarise", json=payload)

    assert response.status_code == 500
    error_detail = response.json()["detail"]
    assert "audit write failed" in error_detail