Spaces:
Sleeping
Sleeping
| #!/usr/bin/env python3 | |
| """ | |
| Test Suite for Maternal Health RAG System | |
| Validates complete RAG pipeline including vector retrieval and response generation | |
| """ | |
| import unittest | |
| import time | |
| from typing import List, Dict, Any | |
| from pathlib import Path | |
| from maternal_health_rag import MaternalHealthRAG, QueryResponse | |
class TestMaternalHealthRAG(unittest.TestCase):
    """Test suite for RAG system functionality.

    A single ``MaternalHealthRAG`` instance (with a mock LLM) is shared by
    all tests via ``setUpClass`` to avoid re-initializing the vector store
    for every test method.
    """

    @classmethod
    def setUpClass(cls):
        """Set up test environment once for the whole suite.

        BUG FIX: the original method was missing the ``@classmethod``
        decorator, so unittest's ``cls.setUpClass()`` call raised
        ``TypeError`` (missing positional argument ``cls``) and every
        test errored before running.
        """
        print("π Initializing RAG system for testing...")
        # Mock LLM keeps the suite fast and avoids external API calls.
        cls.rag_system = MaternalHealthRAG(use_mock_llm=True)

    def test_rag_system_initialization(self):
        """Test RAG system initializes correctly."""
        self.assertIsNotNone(self.rag_system.vector_store)
        self.assertIsNotNone(self.rag_system.llm)
        self.assertIsNotNone(self.rag_system.rag_chain)
        # Check system status
        stats = self.rag_system.get_system_stats()
        self.assertEqual(stats['status'], 'initialized')
        self.assertGreater(stats['vector_store']['total_chunks'], 0)

    def test_basic_query_processing(self):
        """Test basic query processing functionality."""
        query = "What is magnesium sulfate used for?"
        response = self.rag_system.query(query)
        # Basic response validation: echoed query, non-empty answer,
        # confidence in [0, 1], and a positive wall-clock time.
        self.assertIsInstance(response, QueryResponse)
        self.assertEqual(response.query, query)
        self.assertIsInstance(response.answer, str)
        self.assertGreater(len(response.answer), 0)
        self.assertGreaterEqual(response.confidence, 0.0)
        self.assertLessEqual(response.confidence, 1.0)
        self.assertGreater(response.response_time, 0.0)

    def test_medical_context_queries(self):
        """Test queries with specific medical context."""
        test_cases = [
            {
                'query': 'What is the dosage for magnesium sulfate in preeclampsia?',
                'content_types': ['dosage', 'emergency'],
                'expected_keywords': ['magnesium', 'dosage', 'preeclampsia'],
                'min_confidence': 0.3
            },
            {
                'query': 'How to manage postpartum hemorrhage emergency?',
                'content_types': ['emergency', 'maternal'],
                'expected_keywords': ['hemorrhage', 'postpartum', 'emergency'],
                'min_confidence': 0.3
            },
            {
                'query': 'Normal fetal heart rate monitoring procedures',
                'content_types': ['procedure', 'maternal'],
                'expected_keywords': ['fetal', 'heart', 'rate', 'monitoring'],
                'min_confidence': 0.3
            }
        ]
        for case in test_cases:
            with self.subTest(query=case['query']):
                response = self.rag_system.query(
                    case['query'],
                    content_types=case['content_types']
                )
                # Response quality checks
                self.assertGreater(len(response.sources), 0)
                self.assertGreaterEqual(response.confidence, case['min_confidence'])
                # Check if expected keywords appear in answer or sources
                combined_text = response.answer.lower()
                if response.sources:
                    combined_text += ' ' + ' '.join([s.content.lower() for s in response.sources])
                keyword_found = any(keyword in combined_text for keyword in case['expected_keywords'])
                self.assertTrue(keyword_found,
                                f"No expected keywords found for query: {case['query']}")

    def test_response_metadata(self):
        """Test response metadata is populated correctly."""
        query = "What are the signs of preeclampsia?"
        response = self.rag_system.query(query)
        # Check metadata fields are present
        required_fields = ['num_sources', 'avg_relevance', 'content_types', 'high_importance_sources']
        for field in required_fields:
            self.assertIn(field, response.metadata)
        # Validate metadata value types/ranges
        self.assertIsInstance(response.metadata['num_sources'], int)
        self.assertGreaterEqual(response.metadata['num_sources'], 0)
        self.assertIsInstance(response.metadata['avg_relevance'], float)
        self.assertIsInstance(response.metadata['content_types'], list)
        self.assertIsInstance(response.metadata['high_importance_sources'], int)

    def test_confidence_scoring(self):
        """Test confidence scoring mechanism."""
        # High-confidence query (should match well)
        high_conf_query = "magnesium sulfate for preeclampsia"
        high_response = self.rag_system.query(high_conf_query)
        # Low-confidence query (less specific)
        low_conf_query = "medical procedures in general"
        low_response = self.rag_system.query(low_conf_query)
        # Confidence should be higher for more specific medical queries
        self.assertGreaterEqual(high_response.confidence, 0.3)
        # Both should have valid confidence scores in [0, 1]
        self.assertGreaterEqual(high_response.confidence, 0.0)
        self.assertLessEqual(high_response.confidence, 1.0)
        self.assertGreaterEqual(low_response.confidence, 0.0)
        self.assertLessEqual(low_response.confidence, 1.0)

    def test_performance_metrics(self):
        """Test RAG system performance."""
        query = "What is normal labor management?"
        # Measure end-to-end wall-clock time around the query call.
        start_time = time.time()
        response = self.rag_system.query(query)
        actual_time = time.time() - start_time
        # Response should be fast (under 2 seconds for mock LLM)
        self.assertLess(response.response_time, 2.0)
        self.assertLess(actual_time, 3.0)
        # Should return relevant sources
        self.assertGreater(len(response.sources), 0)
        self.assertLessEqual(len(response.sources), 10)  # Should be reasonable number

    def test_batch_query_processing(self):
        """Test batch query processing functionality."""
        queries = [
            "What is magnesium sulfate used for?",
            "How to manage labor complications?",
            "Normal fetal heart rate ranges"
        ]
        responses = self.rag_system.batch_query(queries)
        # Should return same number of responses, in the same order
        self.assertEqual(len(responses), len(queries))
        # Each response should be valid
        for i, response in enumerate(responses):
            self.assertIsInstance(response, QueryResponse)
            self.assertEqual(response.query, queries[i])
            self.assertGreater(len(response.answer), 0)

    def test_context_preparation(self):
        """Test context preparation from search results."""
        query = "preeclampsia management guidelines"
        response = self.rag_system.query(query)
        # Should have sources for context
        if response.sources:
            # Context should be prepared from these sources
            self.assertGreater(len(response.answer), 20)  # Reasonable answer length
            # Mock LLM should include safety disclaimer
            self.assertIn("healthcare", response.answer.lower())

    def test_error_handling(self):
        """Test error handling for edge cases."""
        # Empty query must still yield a structured response, not raise.
        empty_response = self.rag_system.query("")
        self.assertIsInstance(empty_response, QueryResponse)
        self.assertIsInstance(empty_response.answer, str)
        # Very long query
        long_query = "What is the management protocol for " + "complicated " * 100 + "pregnancy cases?"
        long_response = self.rag_system.query(long_query)
        self.assertIsInstance(long_response, QueryResponse)
        # Special characters query
        special_query = "What is the dosage for mg++ and other electrolytes?"
        special_response = self.rag_system.query(special_query)
        self.assertIsInstance(special_response, QueryResponse)

    def test_medical_safety_responses(self):
        """Test that responses include appropriate medical safety disclaimers."""
        queries = [
            "What medication should I take for preeclampsia?",
            "How much magnesium sulfate should I give?",
            "What should I do for emergency bleeding?"
        ]
        for query in queries:
            with self.subTest(query=query):
                response = self.rag_system.query(query)
                # Should include medical safety language
                safety_terms = ['consult', 'healthcare', 'professional', 'medical']
                answer_lower = response.answer.lower()
                safety_found = any(term in answer_lower for term in safety_terms)
                self.assertTrue(safety_found,
                                f"No safety disclaimer found in response to: {query}")

    def test_system_statistics(self):
        """Test system statistics functionality."""
        stats = self.rag_system.get_system_stats()
        # Check required top-level fields
        required_fields = ['vector_store', 'rag_config', 'status']
        for field in required_fields:
            self.assertIn(field, stats)
        # Check vector store stats
        self.assertIn('total_chunks', stats['vector_store'])
        self.assertGreater(stats['vector_store']['total_chunks'], 0)
        # Check RAG config matches the mock-LLM setup from setUpClass
        self.assertIn('default_k', stats['rag_config'])
        self.assertIn('llm_type', stats['rag_config'])
        self.assertEqual(stats['rag_config']['llm_type'], 'mock')
def run_comprehensive_rag_tests():
    """Execute the full RAG test suite and print a readable summary.

    Returns:
        bool: True when every test passed, False otherwise.
    """
    print("π§ͺ Running Comprehensive RAG System Tests...")
    print("=" * 60)

    # Build the suite from the test case and run it with per-test output.
    suite = unittest.TestLoader().loadTestsFromTestCase(TestMaternalHealthRAG)
    result = unittest.TextTestRunner(verbosity=2).run(suite)

    # Summary banner with raw counts.
    print("\n" + "=" * 60)
    print("π RAG TEST SUMMARY:")
    print(f" Tests run: {result.testsRun}")
    print(f" Failures: {len(result.failures)}")
    print(f" Errors: {len(result.errors)}")

    all_passed = result.wasSuccessful()
    if all_passed:
        print("β ALL RAG TESTS PASSED! RAG system is production-ready.")
        print("\nπ Key Validations Completed:")
        print(" β RAG system initialization")
        print(" β Query processing with medical context")
        print(" β Confidence scoring and metadata")
        print(" β Performance under 2 seconds")
        print(" β Batch query processing")
        print(" β Medical safety disclaimers")
        print(" β Error handling and edge cases")
    else:
        print("β Some RAG tests failed. Check output above for details.")

        def _report(header, problems, fallback):
            # Print only the final traceback line of each problem test.
            if not problems:
                return
            print(header)
            for test, traceback in problems:
                tail = traceback.strip().split('\n')
                print(f" - {test}: {tail[-1] if tail else fallback}")

        _report("\nFailures:", result.failures, "Unknown failure")
        _report("\nErrors:", result.errors, "Unknown error")

    return all_passed
if __name__ == "__main__":
    # Exit with a nonzero status when the suite fails so CI can detect it.
    # BUG FIX: ``raise SystemExit`` replaces the site-module ``exit()``
    # builtin, which is intended for interactive sessions and is not
    # available when Python runs with -S (no site initialization).
    raise SystemExit(0 if run_comprehensive_rag_tests() else 1)