SmokeScan / tests /test_rag.py
KinetoLabs's picture
Initial commit: FDAM AI Pipeline v4.0.1
88bdcff
raw
history blame
17.2 kB
"""Tests for RAG (Retrieval Augmented Generation) components."""
import pytest
from pathlib import Path
import tempfile
import shutil
from rag.chunker import SemanticChunker, Chunk, chunk_file
from rag.vectorstore import ChromaVectorStore, MockEmbeddingFunction
from rag.retriever import FDAMRetriever, MockReranker, RetrievalResult
class TestSemanticChunker:
    """Exercise the semantic chunker, including table preservation."""

    def test_chunk_simple_document(self):
        """A plain markdown document yields Chunk objects tagged with the
        source/category/priority that were passed in."""
        text = """## Introduction
This is the introduction paragraph with some content.
## Section One
This section contains important information about the topic.
It has multiple sentences to form a proper paragraph.
## Section Two
Another section with different content here.
"""
        pieces = SemanticChunker().chunk_document(
            text=text,
            source="test.md",
            category="methodology",
            priority="primary",
        )
        assert len(pieces) >= 1
        for piece in pieces:
            assert isinstance(piece, Chunk)
            assert piece.source == "test.md"
            assert piece.category == "methodology"
            assert piece.priority == "primary"

    def test_preserve_tables(self):
        """A markdown table must survive chunking as one intact chunk."""
        text = """## Thresholds
| Material | Threshold | Unit |
|----------|-----------|------|
| Lead | 22 | µg/100cm² |
| Cadmium | 3.3 | µg/100cm² |
| Arsenic | 6.7 | µg/100cm² |
## Next Section
Some content after the table.
"""
        pieces = SemanticChunker().chunk_document(
            text=text,
            source="thresholds.md",
            category="thresholds",
            priority="reference-threshold",
        )
        # Locate the chunks flagged as tables; at least one must exist.
        tables = [p for p in pieces if p.content_type == "table"]
        assert len(tables) >= 1
        # Every table row should still be present in that single chunk.
        first = tables[0]
        for needle in ("Lead", "Cadmium", "Arsenic", "|"):
            assert needle in first.text

    def test_extract_keywords(self):
        """Domain keywords should be pulled out of the chunked text."""
        text = """## Zone Classification
The burn zone shows heavy soot deposits and structural damage.
Lead contamination requires HEPA vacuum cleaning per OSHA standards.
"""
        pieces = SemanticChunker().chunk_document(
            text=text,
            source="zones.md",
            category="methodology",
            priority="primary",
        )
        # Pool keywords from every chunk before checking membership.
        keyword_set = {kw for piece in pieces for kw in piece.keywords}
        assert "burn zone" in keyword_set or "heavy" in keyword_set
        assert "soot" in keyword_set or "structural damage" in keyword_set

    def test_chunk_metadata(self):
        """to_metadata() mirrors the chunk's fields."""
        sample = Chunk(
            id="test_001",
            text="Test content",
            source="test.md",
            category="methodology",
            section="## Section 1",
            priority="primary",
            content_type="narrative",
            keywords=["lead", "soot"],
        )
        meta = sample.to_metadata()
        expected = {
            "source": "test.md",
            "category": "methodology",
            "priority": "primary",
            "content_type": "narrative",
        }
        for key, value in expected.items():
            assert meta[key] == value
        assert "lead" in meta["keywords"]
        assert "soot" in meta["keywords"]

    def test_split_by_headers(self):
        """_split_by_headers carves the text at markdown headings."""
        text = """## Section One
Content one.
### Subsection A
Content A.
## Section Two
Content two.
"""
        sections = SemanticChunker()._split_by_headers(text)
        # At minimum the two top-level sections must be represented.
        assert len(sections) >= 2
        headers = [section[0] for section in sections]
        assert any("Section One" in h for h in headers)
        assert any("Section Two" in h for h in headers)
class TestMockEmbeddingFunction:
    """Sanity checks for the deterministic mock embedder."""

    def test_embedding_dimension(self):
        """Each embedding vector has exactly EMBEDDING_DIM components."""
        embedder = MockEmbeddingFunction()
        vectors = embedder(["test text"])
        assert len(vectors) == 1
        assert len(vectors[0]) == embedder.EMBEDDING_DIM

    def test_deterministic_embeddings(self):
        """Embedding the same text twice yields identical vectors."""
        embedder = MockEmbeddingFunction()
        sentence = "This is a test sentence."
        first = embedder([sentence])[0]
        second = embedder([sentence])[0]
        assert first == second

    def test_different_texts_different_embeddings(self):
        """Distinct inputs map to distinct vectors."""
        embedder = MockEmbeddingFunction()
        assert embedder(["First text"])[0] != embedder(["Second text"])[0]

    def test_batch_embeddings(self):
        """A batch call returns one correctly-sized vector per input."""
        embedder = MockEmbeddingFunction()
        batch = ["Text one", "Text two", "Text three"]
        vectors = embedder(batch)
        assert len(vectors) == 3
        for vector in vectors:
            assert len(vector) == embedder.EMBEDDING_DIM
class TestChromaVectorStore:
    """Test ChromaDB vector store add/query/delete/stats behavior."""

    @pytest.fixture
    def temp_dir(self):
        """Yield a throwaway directory for ChromaDB persistence.

        Uses tempfile.TemporaryDirectory instead of a manual
        mkdtemp()/rmtree() pair — the context manager guarantees the
        tree is removed on teardown without hand-rolled cleanup code.
        """
        with tempfile.TemporaryDirectory() as path:
            yield path

    @pytest.fixture
    def vectorstore(self, temp_dir):
        """Vector store backed by the temp directory and mock embeddings."""
        return ChromaVectorStore(
            persist_directory=temp_dir,
            embedding_function=MockEmbeddingFunction(),
        )

    @pytest.fixture
    def sample_chunks(self):
        """Three chunks spanning two sources and two priority tiers."""
        # (id, text, source, category, section, priority, content_type, keywords)
        specs = [
            ("chunk_001",
             "Lead threshold for non-operational facilities is 22 µg/100cm².",
             "fdam.md", "thresholds", "## 1.4 Thresholds",
             "primary", "narrative", ["lead", "non-operational"]),
            ("chunk_002",
             "Burn zone requires structural assessment before cleaning.",
             "fdam.md", "methodology", "## 4.1 Zone Classification",
             "primary", "narrative", ["burn zone", "structural damage"]),
            ("chunk_003",
             "HEPA vacuum is required for soot removal.",
             "cleaning.md", "cleaning-procedures", "## 3.2 Methods",
             "reference-narrative", "narrative", ["hepa", "vacuum", "soot"]),
        ]
        return [
            Chunk(
                id=cid,
                text=text,
                source=src,
                category=cat,
                section=sec,
                priority=prio,
                content_type=ctype,
                keywords=kws,
            )
            for cid, text, src, cat, sec, prio, ctype, kws in specs
        ]

    def test_add_chunks(self, vectorstore, sample_chunks):
        """add_chunks reports the insert count and stats reflect it."""
        assert vectorstore.add_chunks(sample_chunks) == 3
        assert vectorstore.get_stats()["total_chunks"] == 3

    def test_query_returns_results(self, vectorstore, sample_chunks):
        """query returns at most n_results dicts with the expected keys."""
        vectorstore.add_chunks(sample_chunks)
        hits = vectorstore.query("lead threshold", n_results=2)
        assert len(hits) <= 2
        for key in ("id", "document", "metadata", "distance"):
            assert all(key in hit for hit in hits)

    def test_query_with_metadata_filter(self, vectorstore, sample_chunks):
        """A `where` filter restricts results to matching metadata."""
        vectorstore.add_chunks(sample_chunks)
        hits = vectorstore.query(
            "cleaning method",
            n_results=5,
            where={"priority": "primary"},
        )
        assert all(hit["metadata"]["priority"] == "primary" for hit in hits)

    def test_clear_collection(self, vectorstore, sample_chunks):
        """clear() empties the collection."""
        vectorstore.add_chunks(sample_chunks)
        assert vectorstore.get_stats()["total_chunks"] == 3
        vectorstore.clear()
        assert vectorstore.get_stats()["total_chunks"] == 0

    def test_delete_by_source(self, vectorstore, sample_chunks):
        """delete_by_source removes exactly the chunks from that file."""
        vectorstore.add_chunks(sample_chunks)
        # Two of the three sample chunks come from fdam.md.
        assert vectorstore.delete_by_source("fdam.md") == 2
        assert vectorstore.get_stats()["total_chunks"] == 1

    def test_get_stats(self, vectorstore, sample_chunks):
        """get_stats aggregates counts, categories, and priorities."""
        vectorstore.add_chunks(sample_chunks)
        stats = vectorstore.get_stats()
        assert stats["total_chunks"] == 3
        assert "thresholds" in stats["categories"]
        assert "methodology" in stats["categories"]
        assert "primary" in stats["priorities"]
        assert "reference-narrative" in stats["priorities"]
class TestMockReranker:
    """Checks for the lexical-overlap mock reranker."""

    def test_rerank_returns_scores(self):
        """One score per document, each within [0, 1]."""
        scores = MockReranker().rerank(
            "lead threshold contamination",
            [
                "Lead threshold for facilities is 22 µg/100cm².",
                "The weather is nice today.",
                "Contamination levels require assessment.",
            ],
        )
        assert len(scores) == 3
        for score in scores:
            assert 0 <= score <= 1

    def test_relevant_doc_higher_score(self):
        """A document sharing query words outranks an unrelated one."""
        docs = [
            "Lead threshold is 22 µg.",  # shares words with the query
            "Weather forecast for tomorrow.",  # shares none
        ]
        scores = MockReranker().rerank("lead threshold", docs)
        assert scores[0] > scores[1]
class TestFDAMRetriever:
    """Test FDAM retriever: priority weighting, filters, and helpers."""

    @pytest.fixture
    def temp_dir(self):
        """Yield a throwaway directory.

        Uses tempfile.TemporaryDirectory instead of a manual
        mkdtemp()/rmtree() pair — the context manager guarantees the
        tree is removed on teardown without hand-rolled cleanup code.
        """
        with tempfile.TemporaryDirectory() as path:
            yield path

    @pytest.fixture
    def retriever(self, temp_dir):
        """Retriever over a small corpus with one chunk per priority tier."""
        vectorstore = ChromaVectorStore(
            persist_directory=temp_dir,
            embedding_function=MockEmbeddingFunction(),
        )
        # (id, text, source, category, section, priority, content_type, keywords)
        specs = [
            ("primary_001",
             "Lead threshold for non-operational is 22 µg/100cm² per FDAM.",
             "fdam.md", "thresholds", "## Thresholds",
             "primary", "narrative",
             ["lead", "threshold", "non-operational"]),
            ("ref_001",
             "Lead clearance levels from BNL SOP.",
             "bnl.md", "thresholds", "## Attachment 9.3",
             "reference-threshold", "table",
             ["lead", "clearance"]),
            ("ref_002",
             "General cleaning procedures for soot removal.",
             "cleaning.md", "cleaning-procedures", "## Methods",
             "reference-narrative", "narrative",
             ["cleaning", "soot"]),
        ]
        vectorstore.add_chunks([
            Chunk(
                id=cid,
                text=text,
                source=src,
                category=cat,
                section=sec,
                priority=prio,
                content_type=ctype,
                keywords=kws,
            )
            for cid, text, src, cat, sec, prio, ctype, kws in specs
        ])
        return FDAMRetriever(
            vectorstore=vectorstore,
            reranker=MockReranker(),
            use_reranking=True,
        )

    def test_retrieve_returns_results(self, retriever):
        """retrieve returns at most top_k RetrievalResult objects."""
        hits = retriever.retrieve("lead threshold", top_k=3)
        assert len(hits) <= 3
        for hit in hits:
            assert isinstance(hit, RetrievalResult)

    def test_priority_weighting(self, retriever):
        """Primary-source hits carry a populated weighted score."""
        hits = retriever.retrieve("lead threshold", top_k=3)
        primary_hits = [h for h in hits if h.priority == "primary"]
        other_hits = [h for h in hits if h.priority != "primary"]
        if primary_hits and other_hits:
            # Priority weight (primary 1.0 vs reference 0.8-0.9) applies
            # before reranking; reranking may reorder final scores, so
            # only validate that the weighting was applied at all.
            assert primary_hits[0].weighted_score > 0

    def test_category_filter(self, retriever):
        """category_filter restricts hits to that category."""
        hits = retriever.retrieve(
            "cleaning method",
            top_k=5,
            category_filter="cleaning-procedures",
        )
        assert all(h.category == "cleaning-procedures" for h in hits)

    def test_priority_filter(self, retriever):
        """priority_filter restricts hits to that priority."""
        hits = retriever.retrieve(
            "threshold",
            top_k=5,
            priority_filter="primary",
        )
        assert all(h.priority == "primary" for h in hits)

    def test_retrieve_for_context(self, retriever):
        """retrieve_for_context renders hits into a context string."""
        context = retriever.retrieve_for_context("lead threshold", top_k=2)
        assert isinstance(context, str)
        assert "Source:" in context or "No relevant context" in context

    def test_retrieve_thresholds(self, retriever):
        """retrieve_thresholds limits hits to the thresholds category."""
        hits = retriever.retrieve_thresholds(
            material_type="lead",
            facility_type="non-operational",
        )
        assert len(hits) <= 3
        assert all(h.category == "thresholds" for h in hits)

    def test_retrieve_disposition(self, retriever):
        """retrieve_disposition puts a primary source first when any hit."""
        hits = retriever.retrieve_disposition(
            zone="burn-zone",
            condition="heavy",
        )
        if hits:
            assert hits[0].priority == "primary"

    def test_result_to_dict(self, retriever):
        """RetrievalResult.to_dict exposes the expected keys."""
        hits = retriever.retrieve("test", top_k=1)
        if hits:
            payload = hits[0].to_dict()
            for key in (
                "chunk_id",
                "text",
                "source",
                "similarity_score",
                "final_score",
            ):
                assert key in payload

    def test_empty_query_handling(self, retriever):
        """A nonsense query still yields a list (scores just drop)."""
        hits = retriever.retrieve(
            "completely unrelated xyz123",
            top_k=5,
            category_filter="thresholds",
        )
        assert isinstance(hits, list)
class TestChunkFile:
    """Test the chunk_file convenience function."""

    @pytest.fixture
    def temp_md_file(self):
        """Write a small markdown file to disk and yield its Path."""
        handle = tempfile.NamedTemporaryFile(
            mode="w",
            suffix=".md",
            delete=False,
            encoding="utf-8",
        )
        # The with-block closes the handle so the file is readable later.
        with handle:
            handle.write("""## Test Document
This is test content for chunking.
| Column A | Column B |
|----------|----------|
| Value 1 | Value 2 |
""")
        path = Path(handle.name)
        yield path
        path.unlink()

    def test_chunk_file(self, temp_md_file):
        """chunk_file tags every chunk with the file name and category."""
        pieces = chunk_file(
            filepath=temp_md_file,
            category="methodology",
            priority="primary",
        )
        assert len(pieces) >= 1
        for piece in pieces:
            assert piece.source == temp_md_file.name
            assert piece.category == "methodology"