""" tests/test_api_routes.py ======================== Integration tests for the FastAPI REST API (server.py + api/routes.py). These tests exercise the REAL method signatures of all pipeline components (HybridRetriever, ContextBuilder, AnswerChain, KBManager) without mocking internal methods — catching exactly the type of runtime AttributeError that unit tests with mocked pipelines miss. Only the LLM API call and Whisper inference are mocked (network/slow). Everything else is real. Run with: pytest tests/test_api_routes.py -v """ from __future__ import annotations import os os.environ["CUDA_VISIBLE_DEVICES"] = "-1" import json from pathlib import Path from unittest.mock import MagicMock, patch import pytest from fastapi.testclient import TestClient # ------------------------------------------------------------------ # # Fixtures # # ------------------------------------------------------------------ # @pytest.fixture(scope="module") def kb_manager(tmp_path_factory): """Real KBManager backed by a temp SQLite DB.""" from voicevault.kb.kb_manager import KBManager db = tmp_path_factory.mktemp("db") / "test.db" return KBManager(db_path=db) @pytest.fixture(scope="module") def mock_transcriber(): """Groq/Whisper transcriber that returns a fixed transcript.""" from voicevault.models import TranscriptResult t = MagicMock() t.transcribe.return_value = TranscriptResult( transcript="what is machine learning", raw_transcript="what is machine learning", language="en", confidence=1.0, model_used="mock", latency_ms=10, query_type="factual", ) return t @pytest.fixture(scope="module") def mock_answer_chain(): """AnswerChain that returns a fixed GenerationResult.""" from voicevault.generation.answer_chain import GenerationResult chain = MagicMock() chain.generate.return_value = GenerationResult( answer="Machine learning is a subset of AI.", citations=[], confidence_level="high", is_refusal=False, model_used="mock-llm", tokens_used=50, latency_ms=100, ) return chain @pytest.fixture(scope="module") def client(kb_manager, mock_transcriber, mock_answer_chain, tmp_path_factory): """TestClient with all singletons injected — real routing, real method calls.""" import api.routes as routes_mod from voicevault.storage import sqlite_store as db_mod db = tmp_path_factory.mktemp("db2") / "server.db" db_mod.initialize_database(db) routes_mod.init_routes(kb_manager, mock_transcriber, mock_answer_chain, db) from fastapi import FastAPI from fastapi.testclient import TestClient as TC from api.routes import router app = FastAPI() app.include_router(router) return TC(app) # ------------------------------------------------------------------ # # KB Endpoints # # ------------------------------------------------------------------ # class TestKBEndpoints: """Test KB CRUD endpoints with a real KBManager.""" def test_list_kbs_empty(self, client): r = client.get("/api/kbs") assert r.status_code == 200 assert isinstance(r.json(), list) def test_create_kb_success(self, client): r = client.post("/api/kbs", json={"kb_name": "api-test-kb", "display_name": "API Test KB"}) assert r.status_code == 200 data = r.json() assert data["ok"] is True assert data["kb_name"] == "api-test-kb" def test_create_kb_invalid_name(self, client): r = client.post("/api/kbs", json={"kb_name": "Bad Name!", "display_name": "Bad"}) assert r.status_code == 400 def test_create_kb_duplicate(self, client): client.post("/api/kbs", json={"kb_name": "dup-kb", "display_name": "Dup"}) r = client.post("/api/kbs", json={"kb_name": "dup-kb", "display_name": "Dup2"}) assert r.status_code == 400 def test_list_kbs_after_create(self, client): r = client.get("/api/kbs") assert r.status_code == 200 names = [kb["kb_name"] for kb in r.json()] assert "api-test-kb" in names def test_kb_response_has_required_fields(self, client): client.post("/api/kbs", json={"kb_name": "field-test", "display_name": "Field Test"}) r = client.get("/api/kbs") kb = next((k for k in r.json() if k["kb_name"] == "field-test"), None) assert kb is not None assert "kb_name" in kb assert "display_name" in kb assert "is_protected" in kb assert "doc_count" in kb assert "chunk_count" in kb def test_delete_kb_success(self, client): client.post("/api/kbs", json={"kb_name": "to-delete", "display_name": "Delete Me"}) r = client.delete("/api/kbs/to-delete") assert r.status_code == 200 assert r.json()["ok"] is True def test_delete_kb_nonexistent(self, client): r = client.delete("/api/kbs/does-not-exist-xyz") assert r.status_code == 404 # ------------------------------------------------------------------ # # Ask Endpoint — Real Pipeline Method Calls # # ------------------------------------------------------------------ # class TestAskEndpoint: """ Tests that exercise the REAL HybridRetriever.retrieve() and ContextBuilder.build() method signatures — the exact bugs that mocked unit tests miss. """ def test_ask_empty_query_rejected(self, client): r = client.post("/api/ask", json={"query": "", "kb_names": ["api-test-kb"]}) assert r.status_code == 400 def test_ask_no_kbs_rejected(self, client): r = client.post("/api/ask", json={"query": "what is ML?", "kb_names": []}) assert r.status_code == 400 def test_ask_calls_hybrid_retriever_retrieve_method(self, client): """ Ensures HybridRetriever.retrieve() is called (not .search() which does not exist on HybridRetriever — that was the runtime bug). """ with patch("voicevault.retrieval.hybrid_retriever.HybridRetriever.retrieve", return_value=[]) as mock_retrieve: r = client.post("/api/ask", json={ "query": "what is machine learning", "kb_names": ["api-test-kb"], }) # retrieve() was called — not search() mock_retrieve.assert_called_once() assert r.status_code == 200 def test_ask_calls_context_builder_build_method(self, client): """Ensures ContextBuilder.build() is called with the retrieval results.""" with patch("voicevault.retrieval.hybrid_retriever.HybridRetriever.retrieve", return_value=[]): with patch("voicevault.retrieval.context_builder.ContextBuilder.build", return_value=("context text", [])) as mock_build: r = client.post("/api/ask", json={ "query": "what is machine learning", "kb_names": ["api-test-kb"], }) mock_build.assert_called_once() assert r.status_code == 200 def test_ask_response_has_required_fields(self, client): with patch("voicevault.retrieval.hybrid_retriever.HybridRetriever.retrieve", return_value=[]): r = client.post("/api/ask", json={ "query": "what is machine learning", "kb_names": ["api-test-kb"], }) assert r.status_code == 200 data = r.json() assert "answer" in data assert "citations" in data assert "confidence_level" in data assert "is_refusal" in data assert "model_used" in data assert "latency_ms" in data assert "tts_text" in data def test_ask_with_history(self, client): with patch("voicevault.retrieval.hybrid_retriever.HybridRetriever.retrieve", return_value=[]): r = client.post("/api/ask", json={ "query": "tell me more", "kb_names": ["api-test-kb"], "history": [["what is ML?", "ML is a field of AI."]], }) assert r.status_code == 200 # ------------------------------------------------------------------ # # Analytics Endpoint # # ------------------------------------------------------------------ # class TestAnalyticsEndpoint: def test_analytics_returns_stats(self, client): r = client.get("/api/analytics") assert r.status_code == 200 data = r.json() assert "stats" in data assert "kbs" in data assert "total_queries" in data["stats"] assert "avg_latency_ms" in data["stats"] assert "avg_citation_count" in data["stats"] assert "queries_by_day" in data["stats"] def test_analytics_kbs_is_list(self, client): r = client.get("/api/analytics") assert isinstance(r.json()["kbs"], list) # ------------------------------------------------------------------ # # Transcription Endpoint # # ------------------------------------------------------------------ # class TestTranscribeEndpoint: def test_transcribe_wav_file(self, client, tmp_path): """Send a real WAV file and confirm the mock transcriber is called.""" import struct, math # Create a minimal valid WAV (0.5s silence at 16kHz) sample_rate = 16000 num_samples = sample_rate // 2 wav_path = tmp_path / "test.wav" with open(wav_path, "wb") as f: data_size = num_samples * 2 f.write(b"RIFF") f.write(struct.pack("