Spaces:
Sleeping
Sleeping
#!/usr/bin/env python3
"""
Test Suite for Maternal Health Vector Store
Validates search functionality, medical context filtering, and performance
"""
import sys
import time
import unittest
from pathlib import Path

from vector_store_manager import MaternalHealthVectorStore, SearchResult
class TestMaternalHealthVectorStore(unittest.TestCase):
    """Test suite for vector store functionality.

    All tests share a single ``MaternalHealthVectorStore`` instance that is
    prepared once per class in :meth:`setUpClass` and exposed as the class
    attribute ``vector_store``.
    """

    @classmethod
    def setUpClass(cls):
        """Set up the shared vector store used by every test.

        The index is loaded via ``load_existing_index()`` when the store's
        ``index_file`` exists; if the file is missing, or loading reports a
        falsy result, a new index is built with ``create_vector_index()``.

        unittest invokes this hook on the class itself with no arguments, so
        it has to be a ``classmethod``; as a plain function the call fails
        with a ``TypeError`` and every test in the class errors out.
        """
        cls.vector_store = MaternalHealthVectorStore()

        # Load existing vector store (should exist from previous run)
        if cls.vector_store.index_file.exists():
            print("Loading existing vector store for testing...")
            success = cls.vector_store.load_existing_index()
            if not success:
                print("Failed to load existing index, creating new one...")
                cls.vector_store.create_vector_index()
        else:
            print("Creating vector store for testing...")
            cls.vector_store.create_vector_index()

    def test_vector_store_initialization(self):
        """Test vector store loads correctly."""
        self.assertIsNotNone(self.vector_store.index)
        self.assertGreater(self.vector_store.index.ntotal, 0)
        # Documents and metadata are expected to have the same length.
        self.assertEqual(
            len(self.vector_store.documents),
            len(self.vector_store.metadata),
        )

    def test_basic_search_functionality(self):
        """Test basic search returns relevant results."""
        query = "magnesium sulfate dosage for preeclampsia"
        results = self.vector_store.search(query, k=3)

        # Should return results, and never more than the requested k
        self.assertGreater(len(results), 0)
        self.assertLessEqual(len(results), 3)

        # All results should be SearchResult objects that mention the drug
        for result in results:
            self.assertIsInstance(result, SearchResult)
            self.assertGreater(result.score, 0)
            self.assertIn('magnesium', result.content.lower())

    def test_medical_context_filtering(self):
        """Test filtering by medical content types."""
        query = "emergency management protocols"

        # Test filtering by emergency content
        emergency_results = self.vector_store.search_by_medical_context(
            query,
            content_types=['emergency'],
            min_importance=0.8,
            k=5,
        )

        # Every returned result must honour both filters. An empty result
        # list passes: this test checks filtering, not recall.
        for result in emergency_results:
            self.assertEqual(result.chunk_type, 'emergency')
            self.assertGreaterEqual(result.clinical_importance, 0.8)

    def test_clinical_importance_filtering(self):
        """Test filtering by clinical importance."""
        query = "dosage recommendations"

        # Test high importance filtering
        high_importance_results = self.vector_store.search_by_medical_context(
            query,
            min_importance=0.9,
            k=10,
        )

        # All results should have high clinical importance
        for result in high_importance_results:
            self.assertGreaterEqual(result.clinical_importance, 0.9)

    def test_search_performance(self):
        """Test search performance is acceptable."""
        query = "normal labor management guidelines"

        # perf_counter is monotonic and high resolution, so the measured
        # interval cannot be skewed by wall-clock adjustments.
        start_time = time.perf_counter()
        results = self.vector_store.search(query, k=5)
        search_time = time.perf_counter() - start_time

        # Search should be fast (under 1 second)
        self.assertLess(search_time, 1.0)
        self.assertGreater(len(results), 0)

    def test_maternal_health_queries(self):
        """Test specific maternal health queries return relevant results.

        For each case the top result must reach ``min_score`` and at least
        one of the expected keywords must occur somewhere in the combined
        content of the returned results.
        """
        test_cases = [
            {
                'query': 'postpartum hemorrhage management',
                'expected_keywords': ['hemorrhage', 'postpartum', 'bleeding'],
                'min_score': 0.3,
            },
            {
                'query': 'fetal heart rate monitoring',
                'expected_keywords': ['fetal', 'heart', 'rate', 'monitoring'],
                'min_score': 0.3,
            },
            {
                'query': 'preeclampsia treatment protocols',
                'expected_keywords': ['preeclampsia', 'treatment', 'protocol'],
                'min_score': 0.3,
            },
        ]

        for case in test_cases:
            with self.subTest(query=case['query']):
                results = self.vector_store.search(case['query'], k=3)

                # Should return results
                self.assertGreater(len(results), 0)

                # Check relevance of the first (best) result
                best_result = results[0]
                self.assertGreaterEqual(best_result.score, case['min_score'])

                # Check if keywords appear in results
                combined_content = ' '.join(r.content.lower() for r in results)
                keyword_found = any(
                    keyword in combined_content
                    for keyword in case['expected_keywords']
                )
                self.assertTrue(
                    keyword_found,
                    f"No keywords {case['expected_keywords']} found in results",
                )

    def test_statistics_functionality(self):
        """Test vector store statistics are accurate."""
        stats = self.vector_store.get_statistics()

        # Check required fields
        required_fields = [
            'total_chunks', 'embedding_dimension', 'embedding_model',
            'chunk_type_distribution', 'clinical_importance_distribution',
        ]
        for field in required_fields:
            self.assertIn(field, stats)

        # Check values make sense
        self.assertGreater(stats['total_chunks'], 0)
        self.assertEqual(stats['embedding_dimension'], 384)
        self.assertIn('all-MiniLM-L6-v2', stats['embedding_model'])

    def test_dosage_information_retrieval(self):
        """Test retrieval of dosage-specific information.

        Each case restricts the search to a set of content types and passes
        when at least one of its dosage terms occurs in the combined content
        of the returned results.
        """
        dosage_queries = [
            {
                'query': "oxytocin dosage for labor induction",
                # Include maternal and procedure
                'content_types': ['dosage', 'emergency', 'maternal', 'procedure'],
                'dosage_terms': ['oxytocin', 'administration', 'dose', 'mg',
                                 'ml', 'unit', 'continuous'],
            },
            {
                'query': "antibiotic prophylaxis dosing",
                'content_types': ['dosage', 'emergency'],
                'dosage_terms': ['mg', 'ml', 'dose', 'dosage', 'antibiotic',
                                 'prophylaxis'],
            },
            {
                'query': "magnesium sulfate administration",
                'content_types': ['dosage', 'emergency'],
                'dosage_terms': ['magnesium', 'sulfate', 'mg', 'dose',
                                 'administration'],
            },
        ]

        for case in dosage_queries:
            with self.subTest(query=case['query']):
                results = self.vector_store.search_by_medical_context(
                    case['query'],
                    content_types=case['content_types'],
                    k=3,
                )

                # Should find dosage-related content
                self.assertGreater(len(results), 0)

                # Check for dosage-related terms
                combined_content = ' '.join(r.content.lower() for r in results)
                term_found = any(
                    term in combined_content for term in case['dosage_terms']
                )
                self.assertTrue(
                    term_found,
                    f"No dosage terms {case['dosage_terms']} found for query: {case['query']}",
                )

    def test_edge_cases(self):
        """Test edge cases and error handling."""
        # Empty query
        results = self.vector_store.search("", k=1)
        self.assertIsInstance(results, list)

        # Very specific query that might not match well
        results = self.vector_store.search("xyz unknown medical term", k=1)
        self.assertIsInstance(results, list)

        # Large k value
        results = self.vector_store.search("pregnancy", k=100)
        self.assertLessEqual(len(results), 100)
def _print_outcomes(heading, outcomes, fallback):
    """Print one ``- test: last traceback line`` entry per (test, traceback) pair."""
    if not outcomes:
        return
    print(f"\n{heading}:")
    for test, tb_text in outcomes:
        # splitlines() yields [] for an empty traceback, so the fallback text
        # is actually reachable (str.split('\n') always returns at least ['']).
        tb_lines = tb_text.strip().splitlines()
        summary = tb_lines[-1] if tb_lines else fallback
        print(f" - {test}: {summary}")


def run_comprehensive_tests():
    """Run all tests and provide detailed report.

    Loads every test from ``TestMaternalHealthVectorStore``, runs it with a
    verbose ``unittest.TextTestRunner`` and prints a summary, listing the
    last traceback line of each failure and error when the run was not
    successful.

    Returns:
        bool: ``True`` when the run was successful, ``False`` otherwise.
    """
    print("🧪 Running Comprehensive Vector Store Tests...")
    print("=" * 60)

    # Create test suite
    loader = unittest.TestLoader()
    suite = loader.loadTestsFromTestCase(TestMaternalHealthVectorStore)

    # Run tests with detailed output
    runner = unittest.TextTestRunner(verbosity=2)
    result = runner.run(suite)

    # Print summary
    print("\n" + "=" * 60)
    print("📊 TEST SUMMARY:")
    print(f" Tests run: {result.testsRun}")
    print(f" Failures: {len(result.failures)}")
    print(f" Errors: {len(result.errors)}")

    if result.wasSuccessful():
        print("✅ ALL TESTS PASSED! Vector store is working perfectly.")
    else:
        print("❌ Some tests failed. Check output above for details.")
        _print_outcomes("Failures", result.failures, "Unknown failure")
        _print_outcomes("Errors", result.errors, "Unknown error")

    return result.wasSuccessful()
if __name__ == "__main__":
    # Exit with status 0 when all tests passed and 1 otherwise. sys.exit is
    # used instead of the builtin exit(), which is installed by the site
    # module for interactive sessions and may be absent (e.g. python -S).
    success = run_comprehensive_tests()
    sys.exit(0 if success else 1)