Spaces:
Sleeping
Sleeping
| #!/usr/bin/env python3 | |
| """ | |
| Test Suite for Comprehensive Medical Document Chunking | |
| Validates clinical context preservation and chunk quality | |
| """ | |
| import json | |
| import pytest | |
| from pathlib import Path | |
| from typing import Dict, List, Any | |
| import logging | |
# Configure the root logger at INFO level so the validator's progress and
# summary messages (all logged at INFO or above) are emitted when run as a script.
logging.basicConfig(level="INFO")
logger = logging.getLogger(name=__name__)
class ChunkingQualityValidator:
    """Validates the quality of medical document chunking.

    Reads the artefacts produced by the chunking pipeline from ``chunks_dir``:

    * ``comprehensive_chunking_report.json`` -- the aggregate report that most
      ``test_*`` methods receive as ``report``;
    * ``<doc_name>/comprehensive_chunks.json`` -- per-document chunk lists;
    * ``langchain_documents_comprehensive.json`` -- exported LangChain documents.

    Every ``test_*`` method returns ``True``/``False`` and stores the same
    outcome in ``self.test_results`` under a short key. Validation failures are
    raised through ``_require`` instead of bare ``assert`` statements, so the
    checks still run when Python is started with ``-O``.
    """

    # Keys every chunk in a per-document chunk file must carry.
    REQUIRED_CHUNK_FIELDS = ('content', 'chunk_type', 'clinical_importance', 'medical_context')
    # Metadata keys whose presence marks a LangChain document as carrying medical context.
    MEDICAL_METADATA_FIELDS = (
        'chunk_type', 'clinical_importance', 'keywords',
        'has_clinical_protocols', 'has_dosage_info', 'is_maternal_specific',
    )

    def __init__(self, chunks_dir: Path = Path("comprehensive_chunks")):
        """Create a validator.

        Args:
            chunks_dir: Directory containing the chunking output. A ``str`` is
                accepted as well and converted to ``Path``.
        """
        self.chunks_dir = Path(chunks_dir)
        # Short test key (e.g. 'basic_statistics') -> pass/fail of the last run.
        self.test_results: Dict[str, bool] = {}

    @staticmethod
    def _require(condition: bool, message: str) -> None:
        """Raise ``AssertionError(message)`` if ``condition`` is false.

        Used instead of ``assert`` because assert statements are removed when
        Python runs with ``-O``, which would make every check pass silently.
        """
        if not condition:
            raise AssertionError(message)

    def _run_check(self, key: str, label: str, check) -> bool:
        """Run ``check()``, then log, record and return its outcome.

        Args:
            key: Key under which the result is stored in ``self.test_results``.
            label: Human-readable name used in the failure log messages.
            check: Zero-argument callable. It raises ``AssertionError`` on a
                validation failure; any other exception is reported as an error.

        Returns:
            ``True`` if ``check()`` completed without raising, else ``False``.
        """
        try:
            check()
        except AssertionError as e:
            logger.error(f"β {label} test failed: {e}")
            self.test_results[key] = False
            return False
        except Exception as e:
            # Malformed input (missing report keys, unreadable files, ...) is a
            # failed test with a recorded result, not an escaping exception.
            logger.error(f"β {label} test error: {e}")
            self.test_results[key] = False
            return False
        self.test_results[key] = True
        return True

    def load_chunking_report(self) -> Dict[str, Any]:
        """Load ``comprehensive_chunking_report.json`` from ``chunks_dir``.

        Raises:
            FileNotFoundError: If the report file does not exist.
        """
        report_file = self.chunks_dir / "comprehensive_chunking_report.json"
        if not report_file.exists():
            raise FileNotFoundError(f"Chunking report not found: {report_file}")
        # JSON is UTF-8; do not depend on the platform's default encoding.
        with open(report_file, encoding="utf-8") as f:
            return json.load(f)

    def load_sample_chunks(self, doc_name: str, limit: int = 5) -> List[Dict]:
        """Return the first ``limit`` chunks of ``doc_name``.

        Returns an empty list when the document has no chunk file.
        """
        doc_chunks_file = self.chunks_dir / doc_name / "comprehensive_chunks.json"
        if not doc_chunks_file.exists():
            return []
        with open(doc_chunks_file, encoding="utf-8") as f:
            chunks = json.load(f)
        return chunks[:limit]

    def test_basic_statistics(self, report: Dict[str, Any]) -> bool:
        """Check overall counts and that some high-importance chunks exist.

        Requires positive chunk and document totals, at least one chunk per
        document, a non-empty chunk-type distribution, and at least one chunk
        rated 'critical' or 'high'.
        """
        logger.info("Testing basic chunking statistics...")

        def check() -> None:
            total_chunks = report['total_chunks']
            total_docs = report['total_documents']
            self._require(total_chunks > 0, "No chunks were created")
            self._require(total_docs > 0, "No documents were processed")
            self._require(total_chunks >= total_docs, "Too few chunks per document")
            self._require(len(report['chunk_type_distribution']) > 0, "No chunk types identified")
            importance_dist = report['clinical_importance_distribution']
            high_importance = importance_dist.get('critical', 0) + importance_dist.get('high', 0)
            self._require(high_importance > 0, "No high importance chunks found")
            logger.info(f"β Basic statistics: {total_chunks} chunks from {total_docs} documents")

        return self._run_check('basic_statistics', 'Basic statistics', check)

    def test_clinical_content_recognition(self, report: Dict[str, Any]) -> bool:
        """Check that maternal, dosage and table content was recognised.

        Reads the counters in ``report['processing_summary']`` (missing
        counters count as 0). Emergency content is optional and not checked.
        """
        logger.info("Testing clinical content recognition...")

        def check() -> None:
            summary = report['processing_summary']
            maternal_chunks = summary.get('maternal_chunks', 0)
            self._require(maternal_chunks > 0, "No maternal health content identified")
            dosage_chunks = summary.get('dosage_chunks', 0)
            self._require(dosage_chunks > 0, "No dosage information identified")
            table_chunks = summary.get('chunks_with_tables', 0)
            self._require(table_chunks > 0, "No table content preserved")
            logger.info(f"β Clinical content: {maternal_chunks} maternal, {dosage_chunks} dosage, {table_chunks} with tables")

        return self._run_check('clinical_content', 'Clinical content', check)

    @classmethod
    def _is_quality_chunk(cls, chunk: Any) -> bool:
        """Return True if ``chunk`` is complete and has sane values.

        A chunk passes when it is a dict containing every field in
        ``REQUIRED_CHUNK_FIELDS``, its stripped ``content`` is longer than 50
        characters, its ``clinical_importance`` is a number in [0, 1], and its
        ``medical_context`` is a non-empty dict.
        """
        if not isinstance(chunk, dict):
            return False
        if not all(field in chunk for field in cls.REQUIRED_CHUNK_FIELDS):
            return False
        content = chunk['content']
        if not isinstance(content, str) or len(content.strip()) <= 50:
            return False
        importance = chunk['clinical_importance']
        if not isinstance(importance, (int, float)) or not 0 <= importance <= 1:
            return False
        context = chunk['medical_context']
        return isinstance(context, dict) and len(context) > 0

    def test_chunk_quality(self, report: Dict[str, Any]) -> bool:
        """Spot-check individual chunks from the first few documents.

        Samples up to 3 chunks from each of the first 3 documents listed in
        ``report['document_statistics']``. At least 80% of the samples must
        satisfy ``_is_quality_chunk``.
        """
        logger.info("Testing chunk quality...")

        def check() -> None:
            sample_count = 0
            valid_chunks = 0
            for doc_name in list(report['document_statistics'])[:3]:
                for chunk in self.load_sample_chunks(doc_name, limit=3):
                    sample_count += 1
                    # A chunk counts only if every quality criterion holds,
                    # not merely because its fields are present.
                    if self._is_quality_chunk(chunk):
                        valid_chunks += 1
            self._require(sample_count > 0, "No sample chunks could be loaded")
            chunk_quality_ratio = valid_chunks / sample_count
            self._require(chunk_quality_ratio >= 0.8, f"Chunk quality too low: {chunk_quality_ratio:.2f}")
            logger.info(f"β Chunk quality: {valid_chunks}/{sample_count} chunks passed quality checks")

        return self._run_check('chunk_quality', 'Chunk quality', check)

    def test_medical_context_preservation(self) -> bool:
        """Check that exported LangChain documents carry medical metadata.

        Reads ``langchain_documents_comprehensive.json`` from ``chunks_dir``.
        At least 80% of its first 20 documents must have one or more keys from
        ``MEDICAL_METADATA_FIELDS`` in their ``metadata``.
        """
        logger.info("Testing medical context preservation...")

        def check() -> None:
            langchain_file = self.chunks_dir / "langchain_documents_comprehensive.json"
            if not langchain_file.exists():
                raise FileNotFoundError("LangChain documents not found")
            with open(langchain_file, encoding="utf-8") as f:
                langchain_docs = json.load(f)
            sample = langchain_docs[:20]
            total_tested = len(sample)
            self._require(total_tested > 0, "No LangChain documents to test")
            medical_context_count = 0
            for doc in sample:
                metadata = doc.get('metadata') or {}
                if any(field in metadata for field in self.MEDICAL_METADATA_FIELDS):
                    medical_context_count += 1
            context_ratio = medical_context_count / total_tested
            self._require(context_ratio >= 0.8, f"Medical context preservation too low: {context_ratio:.2f}")
            logger.info(f"β Medical context: {medical_context_count}/{total_tested} documents have medical context")

        return self._run_check('medical_context', 'Medical context', check)

    def test_document_coverage(self, report: Dict[str, Any]) -> bool:
        """Check that enough documents were processed and that they yielded chunks.

        Requires at least 10 entries in ``report['document_statistics']``. At
        least 90% of those entries must report ``total_chunks > 0``.
        """
        logger.info("Testing document coverage...")

        def check() -> None:
            doc_stats = report['document_statistics']
            processed_docs = len(doc_stats)
            # Target is all 15 maternal health documents; 10 is the accepted minimum.
            expected_min_docs = 10
            self._require(processed_docs >= expected_min_docs, f"Too few documents processed: {processed_docs}")
            docs_with_good_coverage = sum(1 for stats in doc_stats.values() if stats['total_chunks'] > 0)
            # processed_docs >= expected_min_docs > 0 here, so the division is safe.
            coverage_ratio = docs_with_good_coverage / processed_docs
            self._require(coverage_ratio >= 0.9, f"Document coverage too low: {coverage_ratio:.2f}")
            logger.info(f"β Document coverage: {docs_with_good_coverage}/{processed_docs} documents well covered")

        return self._run_check('document_coverage', 'Document coverage', check)

    def test_clinical_importance_distribution(self, report: Dict[str, Any]) -> bool:
        """Check that at least 30% of chunks are rated 'critical' or 'high'.

        An empty distribution fails the test instead of raising ZeroDivisionError.
        """
        logger.info("Testing clinical importance distribution...")

        def check() -> None:
            importance_dist = report['clinical_importance_distribution']
            total = sum(importance_dist.values())
            self._require(total > 0, "Clinical importance distribution is empty")
            high_count = importance_dist.get('critical', 0) + importance_dist.get('high', 0)
            # Medical guidelines are expected to contain a good share of high-importance content.
            high_importance_ratio = high_count / total
            self._require(high_importance_ratio >= 0.3, f"Too little high-importance content: {high_importance_ratio:.2f}")
            logger.info(f"β Clinical importance: {high_importance_ratio:.1%} high-importance chunks")

        return self._run_check('clinical_importance', 'Clinical importance', check)

    def run_all_tests(self) -> Dict[str, bool]:
        """Run every validation test and log a summary.

        Returns:
            A mapping of human-readable test name to pass/fail. It is empty if
            the chunking report could not be loaded or another error occurred
            outside the individual tests. The overall verdict (at least 80% of
            tests passing) is only logged, not returned.
        """
        logger.info("=" * 80)
        logger.info("STARTING COMPREHENSIVE CHUNKING QUALITY VALIDATION")
        logger.info("=" * 80)
        try:
            report = self.load_chunking_report()
            tests = [
                ('Basic Statistics', lambda: self.test_basic_statistics(report)),
                ('Clinical Content Recognition', lambda: self.test_clinical_content_recognition(report)),
                ('Chunk Quality', lambda: self.test_chunk_quality(report)),
                ('Medical Context Preservation', self.test_medical_context_preservation),
                ('Document Coverage', lambda: self.test_document_coverage(report)),
                ('Clinical Importance Distribution', lambda: self.test_clinical_importance_distribution(report)),
            ]
            results = {}
            passed_tests = 0
            for test_name, test_func in tests:
                logger.info(f"\nπ§ͺ Running: {test_name}")
                try:
                    result = test_func()
                except Exception as e:
                    # Safety net: an unexpected failure in one test must not
                    # abort the remaining tests.
                    logger.error(f"β {test_name} failed with error: {e}")
                    result = False
                results[test_name] = result
                if result:
                    passed_tests += 1

            logger.info("\n" + "=" * 80)
            logger.info("CHUNKING QUALITY VALIDATION SUMMARY")
            logger.info("=" * 80)
            logger.info(f"β Tests Passed: {passed_tests}/{len(tests)}")
            for test_name, result in results.items():
                status = "β PASS" if result else "β FAIL"
                logger.info(f"{status}: {test_name}")
            overall_success = passed_tests >= (len(tests) * 0.8)  # 80% pass rate
            if overall_success:
                logger.info("\nπ OVERALL RESULT: CHUNKING QUALITY VALIDATION PASSED!")
            else:
                logger.info("\nβ οΈ OVERALL RESULT: CHUNKING QUALITY VALIDATION NEEDS IMPROVEMENT")
            logger.info("=" * 80)
            return results
        except Exception as e:
            logger.error(f"β Validation failed with error: {e}")
            return {}
def main():
    """Run the chunking quality validation and save the results as JSON.

    Results are written to ``quality_validation_results.json`` inside the
    validator's ``chunks_dir`` (``comprehensive_chunks`` by default). The file
    holds the per-test results plus a summary with the total tests, passed
    tests and pass rate.
    """
    validator = ChunkingQualityValidator()
    results = validator.run_all_tests()

    # Save next to the chunks being validated instead of re-hard-coding the
    # directory name. Create the directory so a failed run (e.g. the report is
    # missing) still records its results instead of crashing on open().
    test_results_file = validator.chunks_dir / "quality_validation_results.json"
    test_results_file.parent.mkdir(parents=True, exist_ok=True)
    passed_tests = sum(results.values())
    with open(test_results_file, "w", encoding="utf-8") as f:
        json.dump({
            'test_results': results,
            'summary': {
                'total_tests': len(results),
                'passed_tests': passed_tests,
                'pass_rate': passed_tests / len(results) if results else 0,
            },
        }, f, indent=2)
    logger.info(f"π Test results saved to: {test_results_file}")


if __name__ == "__main__":
    main()