Spaces:
Paused
Paused
| """Tests for RAG (Retrieval Augmented Generation) components.""" | |
| import pytest | |
| from pathlib import Path | |
| import tempfile | |
| import shutil | |
| from rag.chunker import SemanticChunker, Chunk, chunk_file | |
| from rag.vectorstore import ChromaVectorStore, MockEmbeddingFunction | |
| from rag.retriever import FDAMRetriever, MockReranker, RetrievalResult | |
class TestSemanticChunker:
    """Test semantic chunker with table preservation."""

    def test_chunk_simple_document(self):
        """Test chunking a simple markdown document."""
        text = """## Introduction
This is the introduction paragraph with some content.
## Section One
This section contains important information about the topic.
It has multiple sentences to form a proper paragraph.
## Section Two
Another section with different content here.
"""
        result = SemanticChunker().chunk_document(
            text=text,
            source="test.md",
            category="methodology",
            priority="primary",
        )
        # At least one chunk, and every chunk carries the supplied metadata.
        assert len(result) >= 1
        for chunk in result:
            assert isinstance(chunk, Chunk)
            assert chunk.source == "test.md"
            assert chunk.category == "methodology"
            assert chunk.priority == "primary"

    def test_preserve_tables(self):
        """Test that tables are kept intact and not split."""
        text = """## Thresholds
| Material | Threshold | Unit |
|----------|-----------|------|
| Lead | 22 | µg/100cm² |
| Cadmium | 3.3 | µg/100cm² |
| Arsenic | 6.7 | µg/100cm² |
## Next Section
Some content after the table.
"""
        result = SemanticChunker().chunk_document(
            text=text,
            source="thresholds.md",
            category="thresholds",
            priority="reference-threshold",
        )
        # Locate the chunk(s) tagged as tables.
        table_chunks = [chunk for chunk in result if chunk.content_type == "table"]
        assert len(table_chunks) >= 1
        # Every row and the pipe delimiters must survive in one chunk.
        body = table_chunks[0].text
        for expected in ("Lead", "Cadmium", "Arsenic", "|"):
            assert expected in body

    def test_extract_keywords(self):
        """Test keyword extraction from text."""
        text = """## Zone Classification
The burn zone shows heavy soot deposits and structural damage.
Lead contamination requires HEPA vacuum cleaning per OSHA standards.
"""
        result = SemanticChunker().chunk_document(
            text=text,
            source="zones.md",
            category="methodology",
            priority="primary",
        )
        # Flatten keywords across all chunks, then look for domain terms.
        keyword_set = {kw for chunk in result for kw in chunk.keywords}
        assert "burn zone" in keyword_set or "heavy" in keyword_set
        assert "soot" in keyword_set or "structural damage" in keyword_set

    def test_chunk_metadata(self):
        """Test chunk metadata conversion."""
        chunk = Chunk(
            id="test_001",
            text="Test content",
            source="test.md",
            category="methodology",
            section="## Section 1",
            priority="primary",
            content_type="narrative",
            keywords=["lead", "soot"],
        )
        metadata = chunk.to_metadata()
        expected = {
            "source": "test.md",
            "category": "methodology",
            "priority": "primary",
            "content_type": "narrative",
        }
        for key, value in expected.items():
            assert metadata[key] == value
        # Keywords are carried through (as a containment check, not equality).
        assert "lead" in metadata["keywords"]
        assert "soot" in metadata["keywords"]

    def test_split_by_headers(self):
        """Test section splitting by markdown headers."""
        text = """## Section One
Content one.
### Subsection A
Content A.
## Section Two
Content two.
"""
        sections = SemanticChunker()._split_by_headers(text)
        # At least the two top-level sections must be recognized.
        assert len(sections) >= 2
        headers = [section[0] for section in sections]
        assert any("Section One" in header for header in headers)
        assert any("Section Two" in header for header in headers)
class TestMockEmbeddingFunction:
    """Test mock embedding function."""

    def test_embedding_dimension(self):
        """Test that embeddings have correct dimension."""
        embedder = MockEmbeddingFunction()
        vectors = embedder(["test text"])
        assert len(vectors) == 1
        assert len(vectors[0]) == embedder.EMBEDDING_DIM

    def test_deterministic_embeddings(self):
        """Test that same text produces same embedding."""
        embedder = MockEmbeddingFunction()
        sentence = "This is a test sentence."
        # Two independent calls with identical input must agree exactly.
        first = embedder([sentence])[0]
        second = embedder([sentence])[0]
        assert first == second

    def test_different_texts_different_embeddings(self):
        """Test that different texts produce different embeddings."""
        embedder = MockEmbeddingFunction()
        assert embedder(["First text"])[0] != embedder(["Second text"])[0]

    def test_batch_embeddings(self):
        """Test embedding multiple texts at once."""
        embedder = MockEmbeddingFunction()
        vectors = embedder(["Text one", "Text two", "Text three"])
        assert len(vectors) == 3
        for vector in vectors:
            assert len(vector) == embedder.EMBEDDING_DIM
class TestChromaVectorStore:
    """Test ChromaDB vector store.

    Fix: ``temp_dir``, ``vectorstore`` and ``sample_chunks`` are consumed as
    test arguments but were never registered as pytest fixtures, so every
    test in this class failed collection with "fixture ... not found".
    They are now decorated with ``@pytest.fixture``.
    """

    @pytest.fixture
    def temp_dir(self):
        """Create a temporary directory for ChromaDB, removed after the test."""
        temp = tempfile.mkdtemp()
        yield temp
        # Teardown: remove the persisted collection directory.
        shutil.rmtree(temp)

    @pytest.fixture
    def vectorstore(self, temp_dir):
        """Create a test vector store persisted in the temporary directory."""
        return ChromaVectorStore(
            persist_directory=temp_dir,
            embedding_function=MockEmbeddingFunction(),
        )

    @pytest.fixture
    def sample_chunks(self):
        """Create sample chunks for testing (two from fdam.md, one from cleaning.md)."""
        return [
            Chunk(
                id="chunk_001",
                text="Lead threshold for non-operational facilities is 22 µg/100cm².",
                source="fdam.md",
                category="thresholds",
                section="## 1.4 Thresholds",
                priority="primary",
                content_type="narrative",
                keywords=["lead", "non-operational"],
            ),
            Chunk(
                id="chunk_002",
                text="Burn zone requires structural assessment before cleaning.",
                source="fdam.md",
                category="methodology",
                section="## 4.1 Zone Classification",
                priority="primary",
                content_type="narrative",
                keywords=["burn zone", "structural damage"],
            ),
            Chunk(
                id="chunk_003",
                text="HEPA vacuum is required for soot removal.",
                source="cleaning.md",
                category="cleaning-procedures",
                section="## 3.2 Methods",
                priority="reference-narrative",
                content_type="narrative",
                keywords=["hepa", "vacuum", "soot"],
            ),
        ]

    def test_add_chunks(self, vectorstore, sample_chunks):
        """Test adding chunks to vector store."""
        count = vectorstore.add_chunks(sample_chunks)
        assert count == 3
        stats = vectorstore.get_stats()
        assert stats["total_chunks"] == 3

    def test_query_returns_results(self, vectorstore, sample_chunks):
        """Test querying the vector store."""
        vectorstore.add_chunks(sample_chunks)
        results = vectorstore.query("lead threshold", n_results=2)
        assert len(results) <= 2
        # Each result must expose the full record shape.
        assert all("id" in r for r in results)
        assert all("document" in r for r in results)
        assert all("metadata" in r for r in results)
        assert all("distance" in r for r in results)

    def test_query_with_metadata_filter(self, vectorstore, sample_chunks):
        """Test querying with metadata filter."""
        vectorstore.add_chunks(sample_chunks)
        results = vectorstore.query(
            "cleaning method",
            n_results=5,
            where={"priority": "primary"},
        )
        # All results should have primary priority
        for r in results:
            assert r["metadata"]["priority"] == "primary"

    def test_clear_collection(self, vectorstore, sample_chunks):
        """Test clearing the collection."""
        vectorstore.add_chunks(sample_chunks)
        assert vectorstore.get_stats()["total_chunks"] == 3
        vectorstore.clear()
        assert vectorstore.get_stats()["total_chunks"] == 0

    def test_delete_by_source(self, vectorstore, sample_chunks):
        """Test deleting chunks by source."""
        vectorstore.add_chunks(sample_chunks)
        deleted = vectorstore.delete_by_source("fdam.md")
        assert deleted == 2  # Two chunks from fdam.md
        stats = vectorstore.get_stats()
        assert stats["total_chunks"] == 1

    def test_get_stats(self, vectorstore, sample_chunks):
        """Test getting collection statistics."""
        vectorstore.add_chunks(sample_chunks)
        stats = vectorstore.get_stats()
        assert stats["total_chunks"] == 3
        assert "thresholds" in stats["categories"]
        assert "methodology" in stats["categories"]
        assert "primary" in stats["priorities"]
        assert "reference-narrative" in stats["priorities"]
class TestMockReranker:
    """Test mock reranker."""

    def test_rerank_returns_scores(self):
        """Test that reranker returns scores."""
        docs = [
            "Lead threshold for facilities is 22 µg/100cm².",
            "The weather is nice today.",
            "Contamination levels require assessment.",
        ]
        scores = MockReranker().rerank("lead threshold contamination", docs)
        # One score per document, each normalized into [0, 1].
        assert len(scores) == 3
        for score in scores:
            assert 0 <= score <= 1

    def test_relevant_doc_higher_score(self):
        """Test that more relevant docs get higher scores."""
        docs = [
            "Lead threshold is 22 µg.",  # Very relevant
            "Weather forecast for tomorrow.",  # Not relevant
        ]
        scores = MockReranker().rerank("lead threshold", docs)
        # The first doc shares more words with the query, so it must rank higher.
        assert scores[0] > scores[1]
class TestFDAMRetriever:
    """Test FDAM retriever with priority weighting.

    Fix: ``temp_dir`` and ``retriever`` are consumed as test arguments but
    were never registered as pytest fixtures, so every test in this class
    failed collection with "fixture ... not found". They are now decorated
    with ``@pytest.fixture``.
    """

    @pytest.fixture
    def temp_dir(self):
        """Create a temporary directory, removed after the test."""
        temp = tempfile.mkdtemp()
        yield temp
        # Teardown: remove the persisted collection directory.
        shutil.rmtree(temp)

    @pytest.fixture
    def retriever(self, temp_dir):
        """Create a test retriever with sample data spanning all priorities."""
        vectorstore = ChromaVectorStore(
            persist_directory=temp_dir,
            embedding_function=MockEmbeddingFunction(),
        )
        # Add sample chunks: one primary, one reference-threshold,
        # one reference-narrative, so priority weighting is observable.
        chunks = [
            Chunk(
                id="primary_001",
                text="Lead threshold for non-operational is 22 µg/100cm² per FDAM.",
                source="fdam.md",
                category="thresholds",
                section="## Thresholds",
                priority="primary",
                content_type="narrative",
                keywords=["lead", "threshold", "non-operational"],
            ),
            Chunk(
                id="ref_001",
                text="Lead clearance levels from BNL SOP.",
                source="bnl.md",
                category="thresholds",
                section="## Attachment 9.3",
                priority="reference-threshold",
                content_type="table",
                keywords=["lead", "clearance"],
            ),
            Chunk(
                id="ref_002",
                text="General cleaning procedures for soot removal.",
                source="cleaning.md",
                category="cleaning-procedures",
                section="## Methods",
                priority="reference-narrative",
                content_type="narrative",
                keywords=["cleaning", "soot"],
            ),
        ]
        vectorstore.add_chunks(chunks)
        return FDAMRetriever(
            vectorstore=vectorstore,
            reranker=MockReranker(),
            use_reranking=True,
        )

    def test_retrieve_returns_results(self, retriever):
        """Test basic retrieval."""
        results = retriever.retrieve("lead threshold", top_k=3)
        assert len(results) <= 3
        assert all(isinstance(r, RetrievalResult) for r in results)

    def test_priority_weighting(self, retriever):
        """Test that primary sources get higher weight."""
        results = retriever.retrieve("lead threshold", top_k=3)
        # Find primary and reference results
        primary_results = [r for r in results if r.priority == "primary"]
        ref_results = [r for r in results if r.priority != "primary"]
        if primary_results and ref_results:
            # Primary should have higher weighted score (before reranking)
            # Note: final_score includes reranking which may change order
            primary = primary_results[0]
            ref = ref_results[0]
            # With similar similarity, primary weight (1.0) > ref weight (0.8-0.9)
            # This test validates the weighting is applied
            assert primary.weighted_score > 0

    def test_category_filter(self, retriever):
        """Test filtering by category."""
        results = retriever.retrieve(
            "cleaning method",
            top_k=5,
            category_filter="cleaning-procedures",
        )
        for r in results:
            assert r.category == "cleaning-procedures"

    def test_priority_filter(self, retriever):
        """Test filtering by priority."""
        results = retriever.retrieve(
            "threshold",
            top_k=5,
            priority_filter="primary",
        )
        for r in results:
            assert r.priority == "primary"

    def test_retrieve_for_context(self, retriever):
        """Test context string generation."""
        context = retriever.retrieve_for_context("lead threshold", top_k=2)
        assert isinstance(context, str)
        assert "Source:" in context or "No relevant context" in context

    def test_retrieve_thresholds(self, retriever):
        """Test threshold-specific retrieval."""
        results = retriever.retrieve_thresholds(
            material_type="lead",
            facility_type="non-operational",
        )
        assert len(results) <= 3
        # Should filter to thresholds category
        for r in results:
            assert r.category == "thresholds"

    def test_retrieve_disposition(self, retriever):
        """Test disposition-specific retrieval."""
        results = retriever.retrieve_disposition(
            zone="burn-zone",
            condition="heavy",
        )
        # Should prefer primary sources
        if results:
            assert results[0].priority == "primary"

    def test_result_to_dict(self, retriever):
        """Test RetrievalResult to_dict method."""
        results = retriever.retrieve("test", top_k=1)
        if results:
            result_dict = results[0].to_dict()
            assert "chunk_id" in result_dict
            assert "text" in result_dict
            assert "source" in result_dict
            assert "similarity_score" in result_dict
            assert "final_score" in result_dict

    def test_empty_query_handling(self, retriever):
        """Test handling of query with no good matches."""
        results = retriever.retrieve(
            "completely unrelated xyz123",
            top_k=5,
            category_filter="thresholds",
        )
        # Should still return results (just lower scores)
        assert isinstance(results, list)
class TestChunkFile:
    """Test the chunk_file convenience function.

    Fix: ``temp_md_file`` is consumed as a test argument but was never
    registered as a pytest fixture, so the test failed collection with
    "fixture 'temp_md_file' not found". It is now decorated with
    ``@pytest.fixture``.
    """

    @pytest.fixture
    def temp_md_file(self):
        """Create a temporary markdown file; unlinked after the test."""
        temp = tempfile.NamedTemporaryFile(
            mode="w",
            suffix=".md",
            delete=False,  # closed then re-read by path, so delete manually
            encoding="utf-8",
        )
        temp.write("""## Test Document
This is test content for chunking.
| Column A | Column B |
|----------|----------|
| Value 1 | Value 2 |
""")
        temp.close()
        yield Path(temp.name)
        # Teardown: remove the file created above.
        Path(temp.name).unlink()

    def test_chunk_file(self, temp_md_file):
        """Test chunking a file directly."""
        chunks = chunk_file(
            filepath=temp_md_file,
            category="methodology",
            priority="primary",
        )
        assert len(chunks) >= 1
        # Source is recorded as the file's basename, not the full path.
        assert all(c.source == temp_md_file.name for c in chunks)
        assert all(c.category == "methodology" for c in chunks)