# hierarchical-rag-eval/tests/test_retrieval.py
# hh786's picture
# Deployment of Hierarchical RAG system
# c54dcef
"""Tests for RAG retrieval pipelines."""
import pytest
import tempfile
import os
from core.index import VectorStore
from core.retrieval import BaseRAG, HierarchicalRAG, RAGComparator
@pytest.fixture
def sample_vector_store():
    """Yield a temporary ``VectorStore`` seeded with two sample chunks.

    The store persists into a ``tempfile.TemporaryDirectory`` that is
    cleaned up once the consuming test has finished. One chunk carries
    "Clinical Care" metadata, the other "Retail Banking" metadata.
    """
    clinical_chunk = {
        "text": "Patient admission requires proper documentation and ID verification.",
        "metadata": {
            "chunk_id": "p1",
            "level1": "Clinical Care",
            "level2": "Patient Records",
            "doc_type": "protocol",
        },
    }
    banking_chunk = {
        "text": "Bank account opening needs KYC compliance and verification.",
        "metadata": {
            "chunk_id": "p2",
            "level1": "Retail Banking",
            "level2": "Account Services",
            "doc_type": "policy",
        },
    }

    with tempfile.TemporaryDirectory() as persist_dir:
        vector_store = VectorStore(
            collection_name="test_rag",
            persist_directory=persist_dir,
        )
        vector_store.add_documents([clinical_chunk, banking_chunk])
        # Yield inside the ``with`` so the directory outlives the test body.
        yield vector_store
def test_base_rag_retrieve(sample_vector_store):
    """Test Base-RAG retrieval.

    Requests up to two results for a query and checks that at least one
    result (and no more than two) came back, that the second returned value
    is positive, and that the first result exposes a ``"document"`` entry.

    Uses ``OPENAI_API_KEY`` from the environment, falling back to a
    placeholder key when it is unset.
    """
    api_key = os.getenv("OPENAI_API_KEY", "dummy-key-for-testing")
    rag = BaseRAG(sample_vector_store, api_key=api_key)

    # Named ``elapsed`` rather than ``time`` so it cannot be mistaken for the
    # standard-library ``time`` module.
    results, elapsed = rag.retrieve("patient admission", n_results=2)

    # Require a non-empty result explicitly: without the lower bound an empty
    # result would pass this check and then fail on ``results[0]`` below with
    # an IndexError instead of a readable assertion failure.
    assert 0 < len(results) <= 2
    assert elapsed > 0
    assert "document" in results[0]
def test_hierarchical_rag_retrieve(sample_vector_store):
    """Test Hierarchical RAG retrieval with filters.

    Retrieves with an explicit ``level1`` filter and checks that the result
    count respects ``n_results``, that the second returned value is positive,
    and that the returned filters echo the requested ``level1`` value.

    Uses ``OPENAI_API_KEY`` from the environment, falling back to a
    placeholder key when it is unset.
    """
    api_key = os.getenv("OPENAI_API_KEY", "dummy-key-for-testing")
    rag = HierarchicalRAG(sample_vector_store, api_key=api_key)

    # Named ``elapsed`` rather than ``time`` so it cannot be mistaken for the
    # standard-library ``time`` module.
    results, elapsed, filters = rag.retrieve(
        "patient admission",
        n_results=2,
        level1="Clinical Care",
    )

    # The previous ``len(results) >= 0`` check could never fail. Bound the
    # count by ``n_results`` instead, matching the Base-RAG retrieval test.
    assert len(results) <= 2
    assert elapsed > 0
    assert filters["level1"] == "Clinical Care"
def test_hierarchy_inference(sample_vector_store):
    """Check that hierarchy inference reports every expected filter key.

    Runs ``infer_hierarchy_from_query`` on a patient-admission question and
    verifies the returned filters contain ``level1``, ``level2``, ``level3``
    and ``doc_type``.
    """
    rag = HierarchicalRAG(
        sample_vector_store,
        api_key=os.getenv("OPENAI_API_KEY", "dummy-key-for-testing"),
    )
    inferred = rag.infer_hierarchy_from_query(
        "What are the patient admission procedures?"
    )

    # Checked in the same order as the hierarchy levels, then the doc type.
    for expected_key in ("level1", "level2", "level3", "doc_type"):
        assert expected_key in inferred
@pytest.mark.skipif(not os.getenv("OPENAI_API_KEY"), reason="No OpenAI API key")
def test_base_rag_generate(sample_vector_store):
    """Check that Base-RAG produces a non-empty, non-error answer.

    Skipped when ``OPENAI_API_KEY`` is not set in the environment.
    """
    question = "What is required for admission?"
    context_passages = [
        "Patient admission requires ID.",
        "Documentation is mandatory.",
    ]

    generator = BaseRAG(sample_vector_store)
    answer, elapsed = generator.generate(question, context_passages, max_tokens=50)

    assert len(answer) > 0
    assert elapsed > 0
    # An answer beginning with the cross mark is treated as an error.
    assert not answer.startswith("❌")
@pytest.mark.skipif(not os.getenv("OPENAI_API_KEY"), reason="No OpenAI API key")
def test_rag_comparator(sample_vector_store):
    """Check the shape of a ``RAGComparator.compare`` result.

    Skipped when ``OPENAI_API_KEY`` is not set in the environment.
    """
    outcome = RAGComparator(sample_vector_store).compare(
        "What is patient admission?",
        n_results=2,
        auto_infer=True,
    )

    # The comparison result must expose each of these entries.
    for required_key in ("query", "base_rag", "hier_rag", "speedup"):
        assert required_key in outcome

    assert outcome["speedup"] > 0
def test_error_handling_invalid_api_key():
    """Check that generation with an invalid API key yields an error answer.

    The call is expected to return an answer string carrying an error marker
    instead of raising.
    """
    with tempfile.TemporaryDirectory() as persist_dir:
        rag = BaseRAG(
            VectorStore(collection_name="test", persist_directory=persist_dir),
            api_key="invalid-key",
        )
        answer, _elapsed = rag.generate("Test query", ["Test context"])

        # Either marker in the answer text counts as a reported error.
        assert any(marker in answer for marker in ("❌", "Error"))
def test_empty_contexts():
    """Check that generation with no contexts still returns a string answer.

    Uses ``OPENAI_API_KEY`` from the environment, falling back to a
    placeholder key when it is unset.
    """
    with tempfile.TemporaryDirectory() as persist_dir:
        empty_store = VectorStore(collection_name="test", persist_directory=persist_dir)
        rag = BaseRAG(
            empty_store,
            api_key=os.getenv("OPENAI_API_KEY", "dummy-key"),
        )

        answer, elapsed = rag.generate("Test query", [])

        # An empty context list must be handled without raising.
        assert isinstance(answer, str)
        assert elapsed >= 0