"""Tests for RAG retrieval pipelines.""" import pytest import tempfile import os from core.index import VectorStore from core.retrieval import BaseRAG, HierarchicalRAG, RAGComparator @pytest.fixture def sample_vector_store(): """Create a sample vector store for testing.""" with tempfile.TemporaryDirectory() as temp_dir: store = VectorStore(collection_name="test_rag", persist_directory=temp_dir) chunks = [ { "text": "Patient admission requires proper documentation and ID verification.", "metadata": { "chunk_id": "p1", "level1": "Clinical Care", "level2": "Patient Records", "doc_type": "protocol" } }, { "text": "Bank account opening needs KYC compliance and verification.", "metadata": { "chunk_id": "p2", "level1": "Retail Banking", "level2": "Account Services", "doc_type": "policy" } } ] store.add_documents(chunks) yield store def test_base_rag_retrieve(sample_vector_store): """Test Base-RAG retrieval.""" api_key = os.getenv("OPENAI_API_KEY", "dummy-key-for-testing") rag = BaseRAG(sample_vector_store, api_key=api_key) results, time = rag.retrieve("patient admission", n_results=2) assert len(results) <= 2 assert time > 0 assert "document" in results[0] def test_hierarchical_rag_retrieve(sample_vector_store): """Test Hierarchical RAG retrieval with filters.""" api_key = os.getenv("OPENAI_API_KEY", "dummy-key-for-testing") rag = HierarchicalRAG(sample_vector_store, api_key=api_key) results, time, filters = rag.retrieve( "patient admission", n_results=2, level1="Clinical Care" ) assert len(results) >= 0 assert time > 0 assert filters["level1"] == "Clinical Care" def test_hierarchy_inference(sample_vector_store): """Test automatic hierarchy inference from query.""" api_key = os.getenv("OPENAI_API_KEY", "dummy-key-for-testing") rag = HierarchicalRAG(sample_vector_store, api_key=api_key) query = "What are the patient admission procedures?" filters = rag.infer_hierarchy_from_query(query) assert "level1" in filters assert "level2" in filters assert "level3" in filters assert "doc_type" in filters @pytest.mark.skipif(not os.getenv("OPENAI_API_KEY"), reason="No OpenAI API key") def test_base_rag_generate(sample_vector_store): """Test Base-RAG answer generation.""" rag = BaseRAG(sample_vector_store) contexts = ["Patient admission requires ID.", "Documentation is mandatory."] answer, time = rag.generate("What is required for admission?", contexts, max_tokens=50) assert len(answer) > 0 assert time > 0 assert not answer.startswith("❌") # No error @pytest.mark.skipif(not os.getenv("OPENAI_API_KEY"), reason="No OpenAI API key") def test_rag_comparator(sample_vector_store): """Test RAG comparison.""" comparator = RAGComparator(sample_vector_store) result = comparator.compare("What is patient admission?", n_results=2, auto_infer=True) assert "query" in result assert "base_rag" in result assert "hier_rag" in result assert "speedup" in result assert result["speedup"] > 0 def test_error_handling_invalid_api_key(): """Test error handling with invalid API key.""" with tempfile.TemporaryDirectory() as temp_dir: store = VectorStore(collection_name="test", persist_directory=temp_dir) rag = BaseRAG(store, api_key="invalid-key") contexts = ["Test context"] answer, time = rag.generate("Test query", contexts) # Should return error message, not crash assert "❌" in answer or "Error" in answer def test_empty_contexts(): """Test handling of empty contexts.""" with tempfile.TemporaryDirectory() as temp_dir: store = VectorStore(collection_name="test", persist_directory=temp_dir) api_key = os.getenv("OPENAI_API_KEY", "dummy-key") rag = BaseRAG(store, api_key=api_key) answer, time = rag.generate("Test query", []) # Should handle gracefully assert isinstance(answer, str) assert time >= 0