# VoiceVault/tests/test_api_routes.py
# NinjainPJs's picture
# Initial release: VoiceVault v1.0.0 — Voice-First RAG Knowledge Agent
# 85f900d
"""
tests/test_api_routes.py
========================
Integration tests for the FastAPI REST API (server.py + api/routes.py).
These tests drive the real FastAPI routing layer and a real KBManager, and
check that the routes reach the pipeline components (HybridRetriever,
ContextBuilder) through their real method names — catching the kind of
runtime AttributeError that unit tests with fully mocked pipelines miss.
The transcriber (Whisper/Groq) and the AnswerChain (LLM call) are replaced
by mocks because they are slow or need the network; several tests also
patch HybridRetriever.retrieve / ContextBuilder.build at class level to
control their return values.
Run with: pytest tests/test_api_routes.py -v
"""
from __future__ import annotations
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
import json
from pathlib import Path
from unittest.mock import MagicMock, patch
import pytest
from fastapi.testclient import TestClient
# ------------------------------------------------------------------ #
# Fixtures #
# ------------------------------------------------------------------ #
@pytest.fixture(scope="module")
def kb_manager(tmp_path_factory):
    """Provide a real KBManager whose SQLite file lives in a fresh temp dir.

    Module-scoped: every test in this file shares the same KB store.
    """
    from voicevault.kb.kb_manager import KBManager

    db_dir = tmp_path_factory.mktemp("db")
    return KBManager(db_path=db_dir / "test.db")
@pytest.fixture(scope="module")
def mock_transcriber():
    """Stand-in for the Groq/Whisper transcriber with a canned transcript.

    ``transcribe()`` always returns the same TranscriptResult, so no audio
    decoding or model inference takes place.
    """
    from voicevault.models import TranscriptResult

    text = "what is machine learning"
    canned = TranscriptResult(
        transcript=text,
        raw_transcript=text,
        language="en",
        confidence=1.0,
        model_used="mock",
        latency_ms=10,
        query_type="factual",
    )
    # NOTE(review): a MagicMock without spec= accepts any attribute name, so a
    # misnamed method call on the transcriber would not raise here. Consider
    # MagicMock(spec=<real transcriber class>) once its import path is confirmed.
    transcriber = MagicMock()
    transcriber.transcribe.return_value = canned
    return transcriber
@pytest.fixture(scope="module")
def mock_answer_chain():
    """AnswerChain stand-in whose ``generate()`` returns a fixed GenerationResult.

    The mock is specced against the real AnswerChain class, so a route that
    calls a method AnswerChain does not define raises AttributeError instead
    of silently receiving a fresh MagicMock. That is the class of bug this
    module exists to catch.
    """
    # NOTE(review): assumes AnswerChain is defined in the same module as
    # GenerationResult — confirm the import path.
    from voicevault.generation.answer_chain import AnswerChain, GenerationResult

    # spec= limits attributes to those defined on the class. Attributes that
    # are only assigned inside AnswerChain.__init__ are NOT included; if a
    # route legitimately reads one, set it explicitly on this mock.
    chain = MagicMock(spec=AnswerChain)
    chain.generate.return_value = GenerationResult(
        answer="Machine learning is a subset of AI.",
        citations=[],
        confidence_level="high",
        is_refusal=False,
        model_used="mock-llm",
        tokens_used=50,
        latency_ms=100,
    )
    return chain
@pytest.fixture(scope="module")
def client(kb_manager, mock_transcriber, mock_answer_chain, tmp_path_factory):
    """Yield a TestClient over a minimal app that mounts the real router.

    ``init_routes`` receives the real KBManager, the mocked transcriber and
    answer chain, and a freshly initialised SQLite database. Requests
    therefore go through the real routing and handler code in api/routes.py.
    The client is closed once the module's tests have finished.
    """
    import api.routes as routes_mod
    from fastapi import FastAPI
    from voicevault.storage import sqlite_store as db_mod

    db = tmp_path_factory.mktemp("db2") / "server.db"
    db_mod.initialize_database(db)
    routes_mod.init_routes(kb_manager, mock_transcriber, mock_answer_chain, db)

    app = FastAPI()
    app.include_router(routes_mod.router)

    # TestClient is imported at module level; no need for a local alias.
    test_client = TestClient(app)
    try:
        yield test_client
    finally:
        # Release the underlying HTTP client/transport at module teardown.
        test_client.close()
# ------------------------------------------------------------------ #
# KB Endpoints #
# ------------------------------------------------------------------ #
class TestKBEndpoints:
    """KB CRUD endpoints backed by the real KBManager.

    All tests share one module-scoped client and database. Each test that
    needs a KB therefore creates it itself instead of relying on an earlier
    test having run.
    """

    def test_list_kbs_empty(self, client):
        """Listing returns a JSON list.

        The list is not asserted to be empty, because the database is shared
        across the module.
        """
        r = client.get("/api/kbs")
        assert r.status_code == 200, r.text
        assert isinstance(r.json(), list)

    def test_create_kb_success(self, client):
        r = client.post("/api/kbs", json={"kb_name": "api-test-kb", "display_name": "API Test KB"})
        assert r.status_code == 200, r.text
        data = r.json()
        assert data["ok"] is True
        assert data["kb_name"] == "api-test-kb"

    def test_create_kb_invalid_name(self, client):
        """A name containing a space and '!' is rejected with 400."""
        r = client.post("/api/kbs", json={"kb_name": "Bad Name!", "display_name": "Bad"})
        assert r.status_code == 400

    def test_create_kb_duplicate(self, client):
        """Creating a KB whose name already exists is rejected with 400."""
        first = client.post("/api/kbs", json={"kb_name": "dup-kb", "display_name": "Dup"})
        # Guard: the 400 below must be caused by the duplicate, not a failed setup.
        assert first.status_code == 200, first.text
        r = client.post("/api/kbs", json={"kb_name": "dup-kb", "display_name": "Dup2"})
        assert r.status_code == 400

    def test_list_kbs_after_create(self, client):
        """A freshly created KB appears in the listing.

        The test creates its own KB so that it does not depend on
        test_create_kb_success having run first.
        """
        created = client.post("/api/kbs", json={"kb_name": "listed-kb", "display_name": "Listed KB"})
        assert created.status_code == 200, created.text
        r = client.get("/api/kbs")
        assert r.status_code == 200, r.text
        names = {kb["kb_name"] for kb in r.json()}
        assert "listed-kb" in names

    def test_kb_response_has_required_fields(self, client):
        created = client.post("/api/kbs", json={"kb_name": "field-test", "display_name": "Field Test"})
        assert created.status_code == 200, created.text
        r = client.get("/api/kbs")
        assert r.status_code == 200, r.text
        kb = next((k for k in r.json() if k["kb_name"] == "field-test"), None)
        assert kb is not None
        required = ("kb_name", "display_name", "is_protected", "doc_count", "chunk_count")
        missing = [field for field in required if field not in kb]
        assert not missing, f"KB listing entry missing fields: {missing}"

    def test_delete_kb_success(self, client):
        created = client.post("/api/kbs", json={"kb_name": "to-delete", "display_name": "Delete Me"})
        # Guard: a delete "success" is only meaningful if the KB existed.
        assert created.status_code == 200, created.text
        r = client.delete("/api/kbs/to-delete")
        assert r.status_code == 200, r.text
        assert r.json()["ok"] is True

    def test_delete_kb_nonexistent(self, client):
        r = client.delete("/api/kbs/does-not-exist-xyz")
        assert r.status_code == 404
# ------------------------------------------------------------------ #
# Ask Endpoint — Real Pipeline Method Calls #
# ------------------------------------------------------------------ #
class TestAskEndpoint:
    """
    /api/ask tests that check the route reaches HybridRetriever.retrieve() and
    ContextBuilder.build() by their real names. ``patch()`` raises
    AttributeError for an attribute that does not exist, and assert_called_*
    confirms the route actually invoked the method. A renamed or misspelled
    call therefore fails here instead of at runtime.
    """

    _RETRIEVE = "voicevault.retrieval.hybrid_retriever.HybridRetriever.retrieve"
    _BUILD = "voicevault.retrieval.context_builder.ContextBuilder.build"
    _KB = "api-test-kb"

    @pytest.fixture(autouse=True)
    def _ensure_kb(self, client):
        """Make sure the KB these tests query exists, regardless of test order.

        A 400 response is accepted because TestKBEndpoints may already have
        created the same KB in this module's shared database.
        """
        r = client.post("/api/kbs", json={"kb_name": self._KB, "display_name": "API Test KB"})
        assert r.status_code in (200, 400), r.text

    def _ask(self, client, query="what is machine learning", **extra):
        """POST /api/ask against this class's KB; *extra* overrides/adds payload keys."""
        return client.post("/api/ask", json={"query": query, "kb_names": [self._KB], **extra})

    def test_ask_empty_query_rejected(self, client):
        r = self._ask(client, query="")
        assert r.status_code == 400

    def test_ask_no_kbs_rejected(self, client):
        r = self._ask(client, query="what is ML?", kb_names=[])
        assert r.status_code == 400

    def test_ask_calls_hybrid_retriever_retrieve_method(self, client):
        """
        The route must call HybridRetriever.retrieve(), not .search(). The
        .search() method does not exist on HybridRetriever and was the
        original runtime bug.
        """
        with patch(self._RETRIEVE, return_value=[]) as mock_retrieve:
            r = self._ask(client)
        # Status first: on a 500, the failure message then shows the response body.
        assert r.status_code == 200, r.text
        mock_retrieve.assert_called_once()

    def test_ask_calls_context_builder_build_method(self, client):
        """The route must call ContextBuilder.build() exactly once."""
        # NOTE(review): build()'s arguments are not checked. The patched return
        # value ("context text", []) assumes build() returns a 2-tuple — confirm
        # against ContextBuilder.
        with patch(self._RETRIEVE, return_value=[]):
            with patch(self._BUILD, return_value=("context text", [])) as mock_build:
                r = self._ask(client)
        assert r.status_code == 200, r.text
        mock_build.assert_called_once()

    def test_ask_response_has_required_fields(self, client):
        with patch(self._RETRIEVE, return_value=[]):
            r = self._ask(client)
        assert r.status_code == 200, r.text
        data = r.json()
        required = (
            "answer",
            "citations",
            "confidence_level",
            "is_refusal",
            "model_used",
            "latency_ms",
            "tts_text",
        )
        missing = [field for field in required if field not in data]
        assert not missing, f"/api/ask response missing fields: {missing}"

    def test_ask_with_history(self, client):
        with patch(self._RETRIEVE, return_value=[]):
            r = self._ask(
                client,
                query="tell me more",
                history=[["what is ML?", "ML is a field of AI."]],
            )
        assert r.status_code == 200, r.text
# ------------------------------------------------------------------ #
# Analytics Endpoint #
# ------------------------------------------------------------------ #
class TestAnalyticsEndpoint:
    """GET /api/analytics returns a ``stats`` section and a ``kbs`` list."""

    _REQUIRED_STATS = ("total_queries", "avg_latency_ms", "avg_citation_count", "queries_by_day")

    def test_analytics_returns_stats(self, client):
        r = client.get("/api/analytics")
        assert r.status_code == 200, r.text
        data = r.json()
        assert "stats" in data
        assert "kbs" in data
        missing = [key for key in self._REQUIRED_STATS if key not in data["stats"]]
        assert not missing, f"analytics stats missing keys: {missing}"

    def test_analytics_kbs_is_list(self, client):
        r = client.get("/api/analytics")
        # Check status first so an error response fails with its body, not a KeyError.
        assert r.status_code == 200, r.text
        assert isinstance(r.json()["kbs"], list)
# ------------------------------------------------------------------ #
# Transcription Endpoint #
# ------------------------------------------------------------------ #
class TestTranscribeEndpoint:
    """POST /api/transcribe with a real WAV upload and the mocked transcriber."""

    @staticmethod
    def _silent_wav_bytes(seconds=0.5, sample_rate=16000):
        """Return a mono, 16-bit PCM WAV file containing *seconds* of silence."""
        import io
        import wave

        num_frames = int(sample_rate * seconds)
        buf = io.BytesIO()
        # wave does not close a caller-supplied file object, so buf stays readable.
        with wave.open(buf, "wb") as wav:
            wav.setnchannels(1)
            wav.setsampwidth(2)  # 16-bit samples
            wav.setframerate(sample_rate)
            wav.writeframes(b"\x00\x00" * num_frames)
        return buf.getvalue()

    def test_transcribe_wav_file(self, client, mock_transcriber):
        """Upload a real WAV file and confirm the mock transcriber handles it."""
        # The transcriber mock is module-scoped and is the same object injected
        # into `client`. Clear its call record first; reset_mock() keeps the
        # configured return_value by default.
        mock_transcriber.transcribe.reset_mock()
        audio = self._silent_wav_bytes()
        r = client.post("/api/transcribe", files={"audio": ("test.wav", audio, "audio/wav")})
        assert r.status_code == 200, r.text
        mock_transcriber.transcribe.assert_called()
        data = r.json()
        assert "transcript" in data
        assert data["transcript"] == "what is machine learning"