Spaces:
Running
Running
| """ | |
| tests/test_api_routes.py | |
| ======================== | |
| Integration tests for the FastAPI REST API (server.py + api/routes.py). | |
| These tests exercise the REAL method signatures of all pipeline components | |
| (HybridRetriever, ContextBuilder, AnswerChain, KBManager) without mocking | |
| internal methods — catching exactly the type of runtime AttributeError that | |
| unit tests with mocked pipelines miss. | |
| Only the LLM API call and Whisper inference are mocked (network/slow). | |
| Everything else is real. | |
| Run with: pytest tests/test_api_routes.py -v | |
| """ | |
| from __future__ import annotations | |
| import os | |
| os.environ["CUDA_VISIBLE_DEVICES"] = "-1" | |
| import json | |
| from pathlib import Path | |
| from unittest.mock import MagicMock, patch | |
| import pytest | |
| from fastapi.testclient import TestClient | |
| # ------------------------------------------------------------------ # | |
| # Fixtures # | |
| # ------------------------------------------------------------------ # | |
@pytest.fixture(scope="module")
def kb_manager(tmp_path_factory):
    """Provide a real KBManager backed by a temporary database file.

    The fixture is module-scoped so that every test in this module talks to
    the same knowledge-base store.

    Args:
        tmp_path_factory: pytest's built-in temporary-directory factory. It is
            used instead of ``tmp_path`` because this fixture lives longer
            than a single test function.

    Returns:
        A ``KBManager`` constructed with ``db_path`` pointing at ``test.db``
        inside a freshly created temporary directory.
    """
    # Imported inside the fixture so the import only happens on first use.
    from voicevault.kb.kb_manager import KBManager

    db = tmp_path_factory.mktemp("db") / "test.db"
    return KBManager(db_path=db)
@pytest.fixture(scope="module")
def mock_transcriber():
    """Provide a stand-in transcriber whose ``transcribe()`` result is fixed.

    Module-scoped because the module-scoped ``client`` fixture depends on it;
    pytest rejects a wider-scoped fixture requesting a narrower-scoped one.

    Returns:
        A ``MagicMock`` whose ``transcribe`` method always returns the same
        ``TranscriptResult`` with the transcript "what is machine learning".
    """
    # Imported inside the fixture so the import only happens on first use.
    from voicevault.models import TranscriptResult

    transcriber = MagicMock()
    transcriber.transcribe.return_value = TranscriptResult(
        transcript="what is machine learning",
        raw_transcript="what is machine learning",
        language="en",
        confidence=1.0,
        model_used="mock",
        latency_ms=10,
        query_type="factual",
    )
    return transcriber
@pytest.fixture(scope="module")
def mock_answer_chain():
    """Provide a stand-in answer chain whose ``generate()`` result is fixed.

    Module-scoped because the module-scoped ``client`` fixture depends on it;
    pytest rejects a wider-scoped fixture requesting a narrower-scoped one.

    Returns:
        A ``MagicMock`` whose ``generate`` method always returns the same
        non-refusal ``GenerationResult`` with no citations.
    """
    # Imported inside the fixture so the import only happens on first use.
    from voicevault.generation.answer_chain import GenerationResult

    chain = MagicMock()
    chain.generate.return_value = GenerationResult(
        answer="Machine learning is a subset of AI.",
        citations=[],
        confidence_level="high",
        is_refusal=False,
        model_used="mock-llm",
        tokens_used=50,
        latency_ms=100,
    )
    return chain
@pytest.fixture(scope="module")
def client(kb_manager, mock_transcriber, mock_answer_chain, tmp_path_factory):
    """Provide a TestClient for an app that mounts the real API router.

    The route module is initialised with the real ``kb_manager`` plus the
    mocked transcriber and answer chain, so requests go through the real
    routing code. Module-scoped so all tests in this module share one app and
    one backing store.

    Args:
        kb_manager: KBManager fixture handed to ``init_routes``.
        mock_transcriber: Transcriber stand-in handed to ``init_routes``.
        mock_answer_chain: Answer-chain stand-in handed to ``init_routes``.
        tmp_path_factory: pytest's built-in temporary-directory factory, used
            to place the server database file.

    Returns:
        A ``TestClient`` bound to a fresh ``FastAPI`` app that includes the
        router from ``api.routes``.
    """
    # Imported inside the fixture so the imports only happen on first use.
    import api.routes as routes_mod
    from fastapi import FastAPI
    from voicevault.storage import sqlite_store as db_mod

    # The server database is separate from the one backing kb_manager.
    db = tmp_path_factory.mktemp("db2") / "server.db"
    db_mod.initialize_database(db)
    routes_mod.init_routes(kb_manager, mock_transcriber, mock_answer_chain, db)

    app = FastAPI()
    app.include_router(routes_mod.router)
    # TestClient is already imported at the top of this module.
    return TestClient(app)
| # ------------------------------------------------------------------ # | |
| # KB Endpoints # | |
| # ------------------------------------------------------------------ # | |
class TestKBEndpoints:
    """Exercise the KB create/list/delete endpoints against a real KBManager."""

    def test_list_kbs_empty(self, client):
        """GET /api/kbs succeeds and returns a JSON list."""
        r = client.get("/api/kbs")
        assert r.status_code == 200
        assert isinstance(r.json(), list)

    def test_create_kb_success(self, client):
        """Creating a KB with a valid name echoes the name back with ok=True."""
        r = client.post(
            "/api/kbs",
            json={"kb_name": "api-test-kb", "display_name": "API Test KB"},
        )
        assert r.status_code == 200
        data = r.json()
        assert data["ok"] is True
        assert data["kb_name"] == "api-test-kb"

    def test_create_kb_invalid_name(self, client):
        """A KB name containing spaces/punctuation is rejected with 400."""
        r = client.post("/api/kbs", json={"kb_name": "Bad Name!", "display_name": "Bad"})
        assert r.status_code == 400

    def test_create_kb_duplicate(self, client):
        """Creating the same KB name twice is rejected with 400 the second time."""
        client.post("/api/kbs", json={"kb_name": "dup-kb", "display_name": "Dup"})
        r = client.post("/api/kbs", json={"kb_name": "dup-kb", "display_name": "Dup2"})
        assert r.status_code == 400

    def test_list_kbs_after_create(self, client):
        """A created KB appears in the listing."""
        # Create the KB here instead of relying on test_create_kb_success
        # having run first. The response is deliberately ignored: when the KB
        # already exists the duplicate create is rejected, which is fine.
        client.post(
            "/api/kbs",
            json={"kb_name": "api-test-kb", "display_name": "API Test KB"},
        )
        r = client.get("/api/kbs")
        assert r.status_code == 200
        names = [kb["kb_name"] for kb in r.json()]
        assert "api-test-kb" in names

    def test_kb_response_has_required_fields(self, client):
        """Each listed KB exposes the fields the front end depends on."""
        client.post("/api/kbs", json={"kb_name": "field-test", "display_name": "Field Test"})
        r = client.get("/api/kbs")
        kb = next((k for k in r.json() if k["kb_name"] == "field-test"), None)
        assert kb is not None
        required = ("kb_name", "display_name", "is_protected", "doc_count", "chunk_count")
        for field in required:
            assert field in kb

    def test_delete_kb_success(self, client):
        """Deleting an existing KB returns 200 with ok=True."""
        client.post("/api/kbs", json={"kb_name": "to-delete", "display_name": "Delete Me"})
        r = client.delete("/api/kbs/to-delete")
        assert r.status_code == 200
        assert r.json()["ok"] is True

    def test_delete_kb_nonexistent(self, client):
        """Deleting an unknown KB returns 404."""
        r = client.delete("/api/kbs/does-not-exist-xyz")
        assert r.status_code == 404
| # ------------------------------------------------------------------ # | |
| # Ask Endpoint — Real Pipeline Method Calls # | |
| # ------------------------------------------------------------------ # | |
class TestAskEndpoint:
    """
    Checks on POST /api/ask.

    The route is expected to go through HybridRetriever.retrieve() and
    ContextBuilder.build(); calling a method that does not exist on those
    classes is the kind of runtime failure these tests are meant to surface.
    """

    # Dotted paths handed to unittest.mock.patch.
    _RETRIEVE_TARGET = "voicevault.retrieval.hybrid_retriever.HybridRetriever.retrieve"
    _BUILD_TARGET = "voicevault.retrieval.context_builder.ContextBuilder.build"

    def test_ask_empty_query_rejected(self, client):
        """An empty query string yields 400."""
        response = client.post("/api/ask", json={"query": "", "kb_names": ["api-test-kb"]})
        assert response.status_code == 400

    def test_ask_no_kbs_rejected(self, client):
        """An empty kb_names list yields 400."""
        response = client.post("/api/ask", json={"query": "what is ML?", "kb_names": []})
        assert response.status_code == 400

    def test_ask_calls_hybrid_retriever_retrieve_method(self, client):
        """
        The route must call HybridRetriever.retrieve(); .search() does not
        exist on HybridRetriever, and calling it was the runtime bug.
        """
        payload = {
            "query": "what is machine learning",
            "kb_names": ["api-test-kb"],
        }
        with patch(self._RETRIEVE_TARGET, return_value=[]) as mock_retrieve:
            response = client.post("/api/ask", json=payload)
        # retrieve() was called — not search()
        mock_retrieve.assert_called_once()
        assert response.status_code == 200

    def test_ask_calls_context_builder_build_method(self, client):
        """The route must call ContextBuilder.build() exactly once."""
        payload = {
            "query": "what is machine learning",
            "kb_names": ["api-test-kb"],
        }
        retrieve_patch = patch(self._RETRIEVE_TARGET, return_value=[])
        build_patch = patch(self._BUILD_TARGET, return_value=("context text", []))
        with retrieve_patch, build_patch as mock_build:
            response = client.post("/api/ask", json=payload)
        mock_build.assert_called_once()
        assert response.status_code == 200

    def test_ask_response_has_required_fields(self, client):
        """A successful answer carries every field the client reads."""
        payload = {
            "query": "what is machine learning",
            "kb_names": ["api-test-kb"],
        }
        with patch(self._RETRIEVE_TARGET, return_value=[]):
            response = client.post("/api/ask", json=payload)
        assert response.status_code == 200
        body = response.json()
        expected_keys = (
            "answer",
            "citations",
            "confidence_level",
            "is_refusal",
            "model_used",
            "latency_ms",
            "tts_text",
        )
        for key in expected_keys:
            assert key in body

    def test_ask_with_history(self, client):
        """A request that includes prior turns is accepted."""
        payload = {
            "query": "tell me more",
            "kb_names": ["api-test-kb"],
            "history": [["what is ML?", "ML is a field of AI."]],
        }
        with patch(self._RETRIEVE_TARGET, return_value=[]):
            response = client.post("/api/ask", json=payload)
        assert response.status_code == 200
| # ------------------------------------------------------------------ # | |
| # Analytics Endpoint # | |
| # ------------------------------------------------------------------ # | |
class TestAnalyticsEndpoint:
    """Checks on the shape of the GET /api/analytics payload."""

    def test_analytics_returns_stats(self, client):
        """The payload has 'stats' and 'kbs', and 'stats' has each metric key."""
        response = client.get("/api/analytics")
        assert response.status_code == 200
        payload = response.json()
        for top_level_key in ("stats", "kbs"):
            assert top_level_key in payload
        stats = payload["stats"]
        for metric in (
            "total_queries",
            "avg_latency_ms",
            "avg_citation_count",
            "queries_by_day",
        ):
            assert metric in stats

    def test_analytics_kbs_is_list(self, client):
        """The 'kbs' entry is a JSON array."""
        payload = client.get("/api/analytics").json()
        assert isinstance(payload["kbs"], list)
| # ------------------------------------------------------------------ # | |
| # Transcription Endpoint # | |
| # ------------------------------------------------------------------ # | |
class TestTranscribeEndpoint:
    """Checks on POST /api/transcribe using the mocked transcriber."""

    def test_transcribe_wav_file(self, client, tmp_path):
        """Upload a valid WAV file and expect the mock transcriber's transcript."""
        # Function-scope import: the wave module is only needed by this test.
        import wave

        # Create a minimal valid WAV (0.5s silence at 16kHz, mono, 16-bit PCM).
        sample_rate = 16000
        num_samples = sample_rate // 2
        bytes_per_sample = 2
        wav_path = tmp_path / "test.wav"
        # The stdlib writer emits the RIFF/fmt/data headers and fills in the
        # sizes, so the header no longer has to be packed by hand.
        with wave.open(str(wav_path), "wb") as wav:
            wav.setnchannels(1)
            wav.setsampwidth(bytes_per_sample)
            wav.setframerate(sample_rate)
            wav.writeframes(b"\x00" * (num_samples * bytes_per_sample))

        with open(wav_path, "rb") as f:
            r = client.post("/api/transcribe", files={"audio": ("test.wav", f, "audio/wav")})
        assert r.status_code == 200
        data = r.json()
        assert "transcript" in data
        assert data["transcript"] == "what is machine learning"