File size: 5,943 Bytes
7c46845
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
from unittest.mock import MagicMock

from app.rag import retriever
from app.models import Document


# ── Retrieval: document_ids reaches both vector and BM25 (dev's direct-call retrieve) ──

def _mock_db(monkeypatch, doc_rows):
    mock_db = MagicMock()
    mock_db.__enter__.return_value = mock_db
    mock_query = MagicMock()
    mock_db.query.return_value = mock_query
    mock_query.filter.return_value.all.return_value = doc_rows
    monkeypatch.setattr("app.database.SessionLocal", lambda: mock_db)


def test_retrieve_forwards_document_ids_to_vector_and_bm25(monkeypatch):
    """A ``document_ids`` list given to ``retrieve`` reaches vector search and BM25 unchanged.

    BM25 only runs when hybrid search is enabled, so its check is conditional: the
    "unset" placeholder is still in place if the BM25 fake was never called.
    """
    expected = ["doc-a", "doc-b"]
    _mock_db(monkeypatch, [("doc-a",), ("doc-b",)])

    seen = dict.fromkeys(("vector", "bm25"), "unset")

    # Fixed stand-ins for query rewriting, embedding and reranking.
    for attr, stub in (
        ("transform_query", lambda _q: ["q"]),
        ("embed_query", lambda q: f"embedding:{q}"),
        ("get_reranker", lambda: None),
    ):
        monkeypatch.setattr(retriever, attr, stub)

    # Parameter names mirror the real signatures so keyword calls still bind.
    def fake_vector_search(query_embedding, user_id, document_id=None, document_ids=None, top_k=10):
        seen["vector"] = document_ids
        return [{"id": "v1", "text": "vec", "filename": "a.pdf", "page": 1, "score": 0.5}]

    def fake_bm25_search(query, user_id, document_id=None, document_ids=None, top_k=10):
        seen["bm25"] = document_ids
        return [{"id": "b1", "text": "bm", "filename": "b.pdf", "page": 1, "score": 0.5}]

    monkeypatch.setattr(retriever, "query_chunks", fake_vector_search)
    monkeypatch.setattr("app.rag.bm25.query_bm25", fake_bm25_search)

    retriever.retrieve("question", user_id="user-1", document_ids=["doc-a", "doc-b"])

    assert seen["vector"] == expected
    bm25_ran = seen["bm25"] != "unset"
    if bm25_ran:
        assert seen["bm25"] == expected


def test_retrieve_single_document_leaves_document_ids_none(monkeypatch):
    """``retrieve(document_id=...)`` forwards the single id and leaves ``document_ids`` as None.

    Both retrieval back-ends are replaced with fakes. With hybrid search enabled, the
    test therefore never reaches the real BM25 implementation through the mocked DB
    session. If BM25 does run, it must receive the same single-document scoping as
    the vector search. BM25 only runs when hybrid search is enabled, so its check is
    conditional on the fake having been called.
    """
    _mock_db(monkeypatch, [("doc-a",)])

    seen = {
        "vector_id": "unset",
        "vector_ids": "unset",
        "bm25_id": "unset",
        "bm25_ids": "unset",
    }

    monkeypatch.setattr(retriever, "transform_query", lambda _q: ["q"])
    monkeypatch.setattr(retriever, "embed_query", lambda q: f"embedding:{q}")
    monkeypatch.setattr(retriever, "get_reranker", lambda: None)

    def fake_query_chunks(query_embedding, user_id, document_id=None, document_ids=None, top_k=10):
        seen["vector_id"] = document_id
        seen["vector_ids"] = document_ids
        return [{"id": "v1", "text": "vec", "filename": "a.pdf", "page": 1, "score": 0.5}]

    def fake_query_bm25(query, user_id, document_id=None, document_ids=None, top_k=10):
        seen["bm25_id"] = document_id
        seen["bm25_ids"] = document_ids
        return [{"id": "b1", "text": "bm", "filename": "a.pdf", "page": 1, "score": 0.5}]

    monkeypatch.setattr(retriever, "query_chunks", fake_query_chunks)
    # Same patch target as the multi-document test, so hybrid mode is isolated too.
    monkeypatch.setattr("app.rag.bm25.query_bm25", fake_query_bm25)

    retriever.retrieve("question", user_id="user-1", document_id="doc-a")

    assert seen["vector_id"] == "doc-a"
    assert seen["vector_ids"] is None
    if seen["bm25_ids"] != "unset":
        assert seen["bm25_id"] == "doc-a"
        assert seen["bm25_ids"] is None


# ── Prompt: comparison guidance only when more than one document ──

def test_comparison_guidance_present_only_for_multiple_documents(monkeypatch):
    """The multi-document comparison guidance appears in the agent prompt only for >1 document."""
    from app.rag import agent
    from app.rag.prompts import MULTI_DOC_COMPARISON_GUIDANCE

    guidance = MULTI_DOC_COMPARISON_GUIDANCE.strip()
    templates = {}

    class FakeLLM:
        """Inert stand-in for the LLM client returned by ``get_llm_client``."""

        def __init__(self, *args, **kwargs):
            pass

    # Keep the real parameter names so a keyword call from ``agent`` still binds.
    def capture_prompt(llm, tools, prompt):
        templates["template"] = prompt.template
        return "agent"

    monkeypatch.setattr(agent, "get_llm_client", lambda hf_token=None: FakeLLM())
    monkeypatch.setattr(agent, "create_react_agent", capture_prompt)
    monkeypatch.setattr(agent, "AgentExecutor", lambda **kwargs: kwargs)

    # Two documents: the guidance must be in the template.
    agent.get_agent_executor(user_id="user-1", document_ids=["doc-a", "doc-b"])
    assert guidance in templates["template"]

    # One document: clear first so only this call's template is inspected
    # (a missing call surfaces as KeyError instead of reusing the old value).
    templates.clear()
    agent.get_agent_executor(user_id="user-1", document_id="doc-a")
    assert guidance not in templates["template"]


# ── Route guard: ownership + readiness for document_ids ──

def test_chat_ask_multi_doc_success(client, auth_headers, ready_document, db_session, user, monkeypatch):
    """Two owned, ready documents are accepted and both ids reach ``generate_answer``."""
    second = Document(
        user_id=user.id,
        filename="second.txt",
        original_name="second.txt",
        file_size=128,
        status="ready",
    )
    db_session.add(second)
    db_session.commit()
    db_session.refresh(second)

    # Capture the ids up front so the assertions below don't depend on ORM
    # attribute state after the request has run.
    requested_ids = [ready_document.id, second.id]
    forwarded = {}

    def fake_generate_answer(question, user_id, document_id=None, document_ids=None, **kwargs):
        # Record what the route forwarded; the stub answer alone cannot prove it.
        forwarded["document_ids"] = document_ids
        return {
            "answer": "Across both docs",
            "sources": [],
        }

    monkeypatch.setattr("app.routes.chat.generate_answer", fake_generate_answer)

    response = client.post(
        "/api/v1/chat/ask",
        headers=auth_headers,
        json={"question": "Compare them", "document_ids": requested_ids},
    )

    assert response.status_code == 200
    assert response.json()["answer"] == "Across both docs"
    # Compare as sets so the check does not depend on the order the route passes ids in.
    assert forwarded.get("document_ids") is not None
    assert set(forwarded["document_ids"]) == set(requested_ids)


def test_chat_ask_multi_doc_rejects_missing_document(client, auth_headers, ready_document):
    """A ``document_ids`` entry that matches no document makes the request fail with 404."""
    payload = {
        "question": "Compare",
        "document_ids": [ready_document.id, "missing-doc-id"],
    }
    resp = client.post("/api/v1/chat/ask", headers=auth_headers, json=payload)
    assert resp.status_code == 404


def test_chat_ask_multi_doc_rejects_not_ready_document(client, auth_headers, ready_document, pending_document):
    """Including a document that is not ready (the ``pending_document`` fixture) yields 400."""
    requested = [ready_document.id, pending_document.id]
    resp = client.post(
        "/api/v1/chat/ask",
        headers=auth_headers,
        json={"question": "Compare", "document_ids": requested},
    )
    assert resp.status_code == 400


def test_chat_ask_multi_doc_rejects_other_users_document(client, auth_headers, ready_document, db_session, other_user):
    """A ready document owned by a different user is rejected as if it did not exist (404)."""
    foreign_doc = Document(
        user_id=other_user.id,
        filename="other.txt",
        original_name="other.txt",
        file_size=64,
        status="ready",
    )
    db_session.add(foreign_doc)
    db_session.commit()
    db_session.refresh(foreign_doc)

    payload = {"question": "Compare", "document_ids": [ready_document.id, foreign_doc.id]}
    resp = client.post("/api/v1/chat/ask", headers=auth_headers, json=payload)

    # Ownership failures are reported as "not found" rather than a distinct status.
    assert resp.status_code == 404