# plutoV2_miniProject_3rd-yr / mp1 / test_server.py
# Tests for the Pluto FastAPI server (memory layer and pipeline fixes).
# Source commit: 23cdeed (author: ayushKishor).
import asyncio
import json
from fastapi.testclient import TestClient
from pluto.doc_index import ChunkMeta, DocIndex
from pluto.models import ChunkPlan, ChunkType, FinalAnswer, FinalOutput, Priority, TraceSummary
import pluto.server as server
def test_server_blocks_run_while_background_understanding(monkeypatch):
    """A run request is rejected with 409 while a registered doc is still processing."""
    index = DocIndex()
    index.register_doc(
        doc_id="paper",
        filename="paper.md",
        chunks=["chunk text"],
        chunk_meta=[ChunkMeta(chunk_id="C0", chunk_type="text", mode="MODE_REASONING")],
    )
    index.mark_processing("paper")
    monkeypatch.setattr(server, "_doc_index", index)

    resp = TestClient(server.app).post("/api/run", json={"query": "what is this paper about"})

    assert resp.status_code == 409
    body = resp.json()
    assert "Please wait" in body["error"]
    assert body["processing_docs"] == ["paper"]
def test_server_compare_returns_json_error_payload(monkeypatch):
    """A benchmark failure surfaces as a JSON error payload rather than a 500."""
    import benchmark.compare as compare_module

    class ExplodingRunner:
        def __init__(self, *args, **kwargs):
            pass

        def compare(self, query: str, selected_doc_ids=None, detail_level="standard"):
            raise RuntimeError("benchmark exploded")

    monkeypatch.setattr(compare_module, "ComparisonRunner", ExplodingRunner)

    resp = TestClient(server.app).post("/api/compare", json={"query": "what is this paper about"})

    assert resp.status_code == 200
    assert resp.json()["error"] == "Benchmark error: benchmark exploded"
def test_server_run_forwards_selected_docs_and_detail_level(monkeypatch):
    """/api/run forwards query, doc selection, and detail level to the pipeline runner."""
    seen = {}

    class StubCache:
        def stats(self):
            return {"hits": 0, "misses": 0}

    class StubRunner:
        def __init__(self, *args, **kwargs):
            self.cache = StubCache()

        def on_progress(self, callback):
            seen["progress_callback_registered"] = callable(callback)

        def run(self, query: str, selected_doc_ids=None, detail_level="standard"):
            seen["query"] = query
            seen["selected_doc_ids"] = selected_doc_ids
            seen["detail_level"] = detail_level
            return FinalOutput(
                final_answer=FinalAnswer(response="ok", sections=[]),
                evidence=[],
                trace_summary=TraceSummary(),
                confidence=0.9,
            )

    monkeypatch.setattr(server, "PipelineRunner", StubRunner)
    monkeypatch.setattr(server, "_doc_index", DocIndex())

    request_body = {
        "query": "summarize this",
        "selected_doc_ids": ["paper_a"],
        "detail_level": "detailed",
    }
    resp = TestClient(server.app).post("/api/run", json=request_body)

    assert resp.status_code == 200
    assert resp.json()["session_id"]
    assert seen["progress_callback_registered"] is True
    assert seen["query"] == "summarize this"
    assert seen["selected_doc_ids"] == ["paper_a"]
    assert seen["detail_level"] == "detailed"
def test_server_compare_forwards_selected_docs_and_detail_level(monkeypatch):
    """/api/compare forwards query, doc selection, and detail level to the comparison runner."""
    import benchmark.compare as compare_module

    seen = {}

    class StubRunner:
        def __init__(self, *args, **kwargs):
            pass

        def compare(self, query: str, selected_doc_ids=None, detail_level="standard"):
            seen["query"] = query
            seen["selected_doc_ids"] = selected_doc_ids
            seen["detail_level"] = detail_level
            return {
                "query": query,
                "pluto": {"confidence": 1.0},
                "baseline": {"confidence": 0.5},
                "winner": "Pluto",
            }

    monkeypatch.setattr(compare_module, "ComparisonRunner", StubRunner)
    monkeypatch.setattr(server, "_doc_index", DocIndex())

    request_body = {
        "query": "summarize this",
        "selected_doc_ids": ["paper_a", "paper_b"],
        "detail_level": "detailed",
    }
    resp = TestClient(server.app).post("/api/compare", json=request_body)

    assert resp.status_code == 200
    assert resp.json()["winner"] == "Pluto"
    assert seen["query"] == "summarize this"
    assert seen["selected_doc_ids"] == ["paper_a", "paper_b"]
    assert seen["detail_level"] == "detailed"
def test_server_exposes_processed_docs_as_ready_even_if_status_is_stale(monkeypatch):
    """A doc with an overview reports as ready even when its status flag is stale."""
    index = DocIndex()
    index.register_doc(
        doc_id="agentic_ai",
        filename="agentic ai.pdf",
        chunks=["chunk text"],
        chunk_meta=[ChunkMeta(chunk_id="C0", chunk_type="text", mode="MODE_REASONING")],
    )
    index.set_overview("agentic_ai", "overview text")
    # Simulate a stale flag left behind after processing actually finished.
    index._docs["agentic_ai"].processing_status = "understanding"
    monkeypatch.setattr(server, "_doc_index", index)

    client = TestClient(server.app)
    status_resp = client.get("/api/doc-status/agentic_ai")
    corpus_resp = client.get("/api/corpus")

    assert status_resp.status_code == 200
    assert status_resp.json()["status"] == "ready"
    entries = corpus_resp.json()["documents"]
    entry = next(doc for doc in entries if doc["doc_id"] == "agentic_ai")
    assert entry["processing_status"] == "ready"
    assert entry["is_processed"] is True
def test_stream_progress_serializes_pydantic_payloads(monkeypatch):
    """SSE events carrying pydantic models are emitted as plain JSON, not reprs."""
    session_id = "test-session"
    progress_queue = asyncio.Queue()
    monkeypatch.setattr(server, "session_queues", {session_id: progress_queue})
    monkeypatch.setattr(server, "session_results", {session_id: {"ok": True}})
    monkeypatch.setattr(server, "session_cleanup_tasks", {})

    plan_item = ChunkPlan(
        doc_id="paper",
        chunk_id="C0",
        where="chunk 0",
        chunk_type=ChunkType.TEXT,
        mode="MODE_REASONING",
        priority=Priority.HIGH,
        task="Extract facts",
    )
    progress_queue.put_nowait({
        "stage": "done",
        "status": "complete",
        "payload": {"plan": [plan_item]},
    })

    client = TestClient(server.app)
    with client.stream("GET", f"/api/stream?session_id={session_id}") as resp:
        raw = b"".join(resp.iter_raw()).decode("utf-8")

    assert resp.status_code == 200
    # A repr leak would surface the model class name in the stream body.
    assert "ChunkPlan" not in raw
    event = json.loads(raw.removeprefix("data: ").strip())
    plan_entry = event["payload"]["plan"][0]
    assert plan_entry["doc_id"] == "paper"
    assert plan_entry["chunk_type"] == "text"
    assert session_id in server.session_queues
    assert session_id in server.session_results
def test_stream_progress_is_session_scoped(monkeypatch):
    """Each SSE stream drains only the queue belonging to its own session."""
    queue_a = asyncio.Queue()
    queue_b = asyncio.Queue()
    queue_a.put_nowait({"stage": "done", "status": "complete", "session_id": "first"})
    queue_b.put_nowait({"stage": "done", "status": "complete", "session_id": "second"})
    monkeypatch.setattr(server, "session_queues", {"first": queue_a, "second": queue_b})
    monkeypatch.setattr(server, "session_results", {"first": {}, "second": {}})
    monkeypatch.setattr(server, "session_cleanup_tasks", {})

    with TestClient(server.app).stream("GET", "/api/stream?session_id=second") as resp:
        raw = b"".join(resp.iter_raw()).decode("utf-8")

    event = json.loads(raw.removeprefix("data: ").strip())
    assert event["session_id"] == "second"
    assert "first" in server.session_queues
    assert "second" in server.session_queues
def test_session_cleanup_is_delayed(monkeypatch):
    """Session state survives until the cleanup delay elapses, then is removed."""
    async def exercise():
        session_id = "cleanup-session"
        pending_queue = asyncio.Queue()
        monkeypatch.setattr(server, "SESSION_CLEANUP_DELAY_SECONDS", 0.01)
        monkeypatch.setattr(server, "session_queues", {session_id: pending_queue})
        monkeypatch.setattr(server, "session_results", {session_id: {"ok": True}})
        monkeypatch.setattr(server, "session_cleanup_tasks", {})

        server._schedule_session_cleanup(session_id, pending_queue)
        # Immediately after scheduling, nothing has been cleaned up yet.
        assert session_id in server.session_queues
        assert session_id in server.session_results

        await asyncio.sleep(0.05)
        assert session_id not in server.session_queues
        assert session_id not in server.session_results

    asyncio.run(exercise())