Spaces:
Sleeping
Sleeping
| """Tests for RAG retrieval pipelines.""" | |
| import pytest | |
| import tempfile | |
| import os | |
| from core.index import VectorStore | |
| from core.retrieval import BaseRAG, HierarchicalRAG, RAGComparator | |
@pytest.fixture
def sample_vector_store():
    """Yield a VectorStore pre-loaded with two sample chunks.

    One chunk is tagged "Clinical Care" / "Patient Records" (protocol) and
    the other "Retail Banking" / "Account Services" (policy).  The store
    persists into a temporary directory that is cleaned up when the ``with``
    block exits, i.e. after the requesting test has finished.

    Registered with ``@pytest.fixture`` so the tests in this module can
    request it by parameter name.
    """
    with tempfile.TemporaryDirectory() as temp_dir:
        store = VectorStore(collection_name="test_rag", persist_directory=temp_dir)
        chunks = [
            {
                "text": "Patient admission requires proper documentation and ID verification.",
                "metadata": {
                    "chunk_id": "p1",
                    "level1": "Clinical Care",
                    "level2": "Patient Records",
                    "doc_type": "protocol",
                },
            },
            {
                "text": "Bank account opening needs KYC compliance and verification.",
                "metadata": {
                    "chunk_id": "p2",
                    "level1": "Retail Banking",
                    "level2": "Account Services",
                    "doc_type": "policy",
                },
            },
        ]
        store.add_documents(chunks)
        # Yield (not return) so the temporary directory stays alive for the
        # duration of the test and is removed afterwards.
        yield store
def test_base_rag_retrieve(sample_vector_store):
    """Test Base-RAG retrieval.

    Verifies that between one and ``n_results`` hits are returned, that the
    timing value returned alongside them is positive, and that the first hit
    carries a "document" key.
    """
    api_key = os.getenv("OPENAI_API_KEY", "dummy-key-for-testing")
    rag = BaseRAG(sample_vector_store, api_key=api_key)
    results, elapsed = rag.retrieve("patient admission", n_results=2)
    # Require at least one hit: an empty result should fail here with a clear
    # assertion rather than with an IndexError on results[0] below.
    assert 0 < len(results) <= 2
    assert elapsed > 0
    assert "document" in results[0]
def test_hierarchical_rag_retrieve(sample_vector_store):
    """Test Hierarchical RAG retrieval with filters.

    Retrieves with an explicit ``level1`` filter and checks that no more than
    ``n_results`` hits come back, that the returned timing value is positive,
    and that the returned filters echo the requested ``level1``.
    """
    api_key = os.getenv("OPENAI_API_KEY", "dummy-key-for-testing")
    rag = HierarchicalRAG(sample_vector_store, api_key=api_key)
    results, elapsed, filters = rag.retrieve(
        "patient admission",
        n_results=2,
        level1="Clinical Care",
    )
    # Bound the hit count by n_results; the previous `>= 0` check could
    # never fail and therefore verified nothing.
    assert len(results) <= 2
    assert elapsed > 0
    assert filters["level1"] == "Clinical Care"
def test_hierarchy_inference(sample_vector_store):
    """Check that hierarchy inference yields every expected filter key."""
    rag = HierarchicalRAG(
        sample_vector_store,
        api_key=os.getenv("OPENAI_API_KEY", "dummy-key-for-testing"),
    )
    inferred = rag.infer_hierarchy_from_query(
        "What are the patient admission procedures?"
    )
    # Only the presence of each key is checked, not its value.
    for expected_key in ("level1", "level2", "level3", "doc_type"):
        assert expected_key in inferred
def test_base_rag_generate(sample_vector_store):
    """Check that Base-RAG produces a non-empty, non-error answer."""
    question = "What is required for admission?"
    passages = ["Patient admission requires ID.", "Documentation is mandatory."]
    generator = BaseRAG(sample_vector_store)
    answer, elapsed = generator.generate(question, passages, max_tokens=50)
    assert len(answer) > 0
    assert elapsed > 0
    # A leading "❌" is treated as the error marker.
    assert not answer.startswith("❌")
def test_rag_comparator(sample_vector_store):
    """Check the shape of a RAGComparator comparison result."""
    outcome = RAGComparator(sample_vector_store).compare(
        "What is patient admission?",
        n_results=2,
        auto_infer=True,
    )
    # All top-level keys must be present before the speedup value is read.
    for field in ("query", "base_rag", "hier_rag", "speedup"):
        assert field in outcome
    assert outcome["speedup"] > 0
def test_error_handling_invalid_api_key():
    """Check that generation with an invalid API key reports an error."""
    with tempfile.TemporaryDirectory() as temp_dir:
        rag = BaseRAG(
            VectorStore(collection_name="test", persist_directory=temp_dir),
            api_key="invalid-key",
        )
        answer, _elapsed = rag.generate("Test query", ["Test context"])
        # The failure is expected to surface in the answer text, not as an
        # exception: either error marker is accepted.
        assert any(marker in answer for marker in ("❌", "Error"))
def test_empty_contexts():
    """Check that generation with no contexts still returns a string answer."""
    with tempfile.TemporaryDirectory() as temp_dir:
        key = os.getenv("OPENAI_API_KEY", "dummy-key")
        empty_store = VectorStore(collection_name="test", persist_directory=temp_dir)
        rag = BaseRAG(empty_store, api_key=key)
        answer, elapsed = rag.generate("Test query", [])
        # Graceful handling: a string comes back and the timing is non-negative.
        assert isinstance(answer, str)
        assert elapsed >= 0