Spaces:
Running
Running
File size: 5,943 Bytes
7c46845 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 | from unittest.mock import MagicMock
from app.rag import retriever
from app.models import Document
# ── Retrieval: document_ids is forwarded to both vector and BM25 search (retrieve() called directly) ──
def _mock_db(monkeypatch, doc_rows):
mock_db = MagicMock()
mock_db.__enter__.return_value = mock_db
mock_query = MagicMock()
mock_db.query.return_value = mock_query
mock_query.filter.return_value.all.return_value = doc_rows
monkeypatch.setattr("app.database.SessionLocal", lambda: mock_db)
def test_retrieve_forwards_document_ids_to_vector_and_bm25(monkeypatch):
    """retrieve(document_ids=[...]) hands the id list to vector search and, if it runs, BM25.

    Query transformation, embedding and reranking are stubbed out. The two
    search backends are replaced with recorders that store the
    ``document_ids`` argument they receive.
    """
    received = dict.fromkeys(("vector", "bm25"), "unset")

    def fake_query_chunks(query_embedding, user_id, document_id=None, document_ids=None, top_k=10):
        received["vector"] = document_ids
        return [{"id": "v1", "text": "vec", "filename": "a.pdf", "page": 1, "score": 0.5}]

    def fake_query_bm25(query, user_id, document_id=None, document_ids=None, top_k=10):
        received["bm25"] = document_ids
        return [{"id": "b1", "text": "bm", "filename": "b.pdf", "page": 1, "score": 0.5}]

    _mock_db(monkeypatch, [("doc-a",), ("doc-b",)])
    for attr, replacement in (
        ("transform_query", lambda _q: ["q"]),
        ("embed_query", lambda q: f"embedding:{q}"),
        ("get_reranker", lambda: None),
        ("query_chunks", fake_query_chunks),
    ):
        monkeypatch.setattr(retriever, attr, replacement)
    # Patched on the bm25 module itself; presumably retriever looks it up there
    # at call time (lazy import) - TODO confirm.
    monkeypatch.setattr("app.rag.bm25.query_bm25", fake_query_bm25)

    retriever.retrieve("question", user_id="user-1", document_ids=["doc-a", "doc-b"])

    assert received["vector"] == ["doc-a", "doc-b"]
    # BM25 only runs when hybrid search is enabled, so its check is conditional.
    # NOTE(review): with hybrid search off, BM25 forwarding is never verified;
    # consider forcing hybrid search on for this test.
    if received["bm25"] != "unset":
        assert received["bm25"] == ["doc-a", "doc-b"]
def test_retrieve_single_document_leaves_document_ids_none(monkeypatch):
    """retrieve(document_id=...) forwards that single id and leaves document_ids as None."""
    received = {"vector_id": "unset", "vector_ids": "unset"}

    def fake_query_chunks(query_embedding, user_id, document_id=None, document_ids=None, top_k=10):
        received.update(vector_id=document_id, vector_ids=document_ids)
        return [{"id": "v1", "text": "vec", "filename": "a.pdf", "page": 1, "score": 0.5}]

    _mock_db(monkeypatch, [("doc-a",)])
    stubs = {
        "transform_query": lambda _q: ["q"],
        "embed_query": lambda q: f"embedding:{q}",
        "get_reranker": lambda: None,
        "query_chunks": fake_query_chunks,
    }
    for attr, replacement in stubs.items():
        monkeypatch.setattr(retriever, attr, replacement)
    # NOTE(review): unlike the multi-document test, query_bm25 is not stubbed
    # here; if hybrid search is enabled the real BM25 path runs - confirm intent.

    retriever.retrieve("question", user_id="user-1", document_id="doc-a")

    assert received["vector_id"] == "doc-a"
    assert received["vector_ids"] is None
# ── Prompt: comparison guidance is included only when more than one document is selected ──
def test_comparison_guidance_present_only_for_multiple_documents(monkeypatch):
    """The agent prompt includes MULTI_DOC_COMPARISON_GUIDANCE only for multi-document scopes.

    The LLM client, the ReAct agent factory and AgentExecutor are stubbed, so
    building an executor only records the prompt template it was given.
    """
    from app.rag import agent
    from app.rag.prompts import MULTI_DOC_COMPARISON_GUIDANCE

    recorded = {}

    class _InertLLM:
        # Stand-in LLM client; the stubbed create_react_agent below ignores it.
        def __init__(self, *args, **kwargs):
            pass

    def record_prompt(llm, tools, prompt):
        recorded["template"] = prompt.template
        return "agent"

    monkeypatch.setattr(agent, "get_llm_client", lambda hf_token=None: _InertLLM())
    monkeypatch.setattr(agent, "create_react_agent", record_prompt)
    monkeypatch.setattr(agent, "AgentExecutor", lambda **kwargs: kwargs)

    def template_for(**scope):
        # Build an executor for this document scope and return the template
        # that was handed to create_react_agent during the call.
        recorded.clear()
        agent.get_agent_executor(user_id="user-1", **scope)
        return recorded["template"]

    guidance = MULTI_DOC_COMPARISON_GUIDANCE.strip()
    assert guidance in template_for(document_ids=["doc-a", "doc-b"])
    assert guidance not in template_for(document_id="doc-a")
# ── Route guard: ownership and readiness checks for document_ids ──
def test_chat_ask_multi_doc_success(client, auth_headers, ready_document, db_session, user, monkeypatch):
    """Asking across two ready, owned documents succeeds and forwards both ids.

    ``app.routes.chat.generate_answer`` is replaced with a stub so no retrieval
    or LLM work runs. The stub records the ``document_ids`` it receives, so the
    test also checks that the route passes both requested documents through,
    not just that it returns 200.
    """
    second = Document(
        user_id=user.id,
        filename="second.txt",
        original_name="second.txt",
        file_size=128,
        status="ready",
    )
    db_session.add(second)
    db_session.commit()
    db_session.refresh(second)

    calls = []

    def fake_generate_answer(question, user_id, document_id=None, document_ids=None, **kwargs):
        calls.append({"question": question, "document_ids": document_ids})
        return {"answer": "Across both docs", "sources": []}

    monkeypatch.setattr("app.routes.chat.generate_answer", fake_generate_answer)

    response = client.post(
        "/api/v1/chat/ask",
        headers=auth_headers,
        json={"question": "Compare them", "document_ids": [ready_document.id, second.id]},
    )

    assert response.status_code == 200
    assert response.json()["answer"] == "Across both docs"
    assert calls, "generate_answer was never called"
    # Compared as a set: the route may normalise ordering of the ids.
    assert set(calls[-1]["document_ids"]) == {ready_document.id, second.id}
def test_chat_ask_multi_doc_rejects_missing_document(client, auth_headers, ready_document):
    """An unknown id among ``document_ids`` makes /chat/ask answer 404."""
    payload = {
        "question": "Compare",
        "document_ids": [ready_document.id, "missing-doc-id"],
    }
    resp = client.post("/api/v1/chat/ask", json=payload, headers=auth_headers)
    assert resp.status_code == 404
def test_chat_ask_multi_doc_rejects_not_ready_document(client, auth_headers, ready_document, pending_document):
    """Including a document that is not ready (the pending fixture) makes /chat/ask answer 400."""
    requested_ids = [ready_document.id, pending_document.id]
    resp = client.post(
        "/api/v1/chat/ask",
        json={"question": "Compare", "document_ids": requested_ids},
        headers=auth_headers,
    )
    assert resp.status_code == 400
def test_chat_ask_multi_doc_rejects_other_users_document(client, auth_headers, ready_document, db_session, other_user):
    """A document owned by another user is rejected with 404, exactly like a missing one."""
    foreign_doc = Document(
        user_id=other_user.id,
        filename="other.txt",
        original_name="other.txt",
        file_size=64,
        status="ready",
    )
    db_session.add(foreign_doc)
    db_session.commit()
    db_session.refresh(foreign_doc)

    payload = {"question": "Compare", "document_ids": [ready_document.id, foreign_doc.id]}
    resp = client.post("/api/v1/chat/ask", json=payload, headers=auth_headers)

    # Not-owned is reported as "not found" (404), presumably so the endpoint
    # does not reveal that another user's document id exists.
    assert resp.status_code == 404