| """Unit tests for text utilities.""" |
|
|
| from unittest.mock import AsyncMock, MagicMock |
|
|
| import pytest |
|
|
| from src.utils.models import Citation, Evidence |
| from src.utils.text_utils import select_diverse_evidence, truncate_at_sentence |
|
|
|
|
| class TestTextUtils: |
| def test_truncate_at_sentence_short(self): |
| """Should return text as is if shorter than limit.""" |
| text = "This is a short sentence." |
| assert truncate_at_sentence(text, 100) == text |
|
|
| def test_truncate_at_sentence_boundary(self): |
| """Should truncate at sentence ending.""" |
| text = "First sentence. Second sentence. Third sentence." |
| |
| limit = len("First sentence. Second sentence") + 5 |
| result = truncate_at_sentence(text, limit) |
| assert result == "First sentence. Second sentence." |
|
|
| def test_truncate_at_sentence_fallback_period(self): |
| """Should fall back to period if no sentence boundary found.""" |
| text = "Dr. Smith went to the store. He bought apples." |
| |
| limit = len("Dr. Smith went to the store.") + 5 |
| result = truncate_at_sentence(text, limit) |
| assert result == "Dr. Smith went to the store." |
|
|
| def test_truncate_at_sentence_fallback_word(self): |
| """Should fall back to word boundary if no punctuation.""" |
| text = "This is a very long sentence without any punctuation marks until the very end" |
| limit = 20 |
| result = truncate_at_sentence(text, limit) |
| assert result == "This is a very long..." |
| |
| assert len(result) <= limit + 3 |
|
|
| @pytest.mark.asyncio |
| async def test_select_diverse_evidence_no_embeddings(self): |
| """Should fallback to relevance sort if no embeddings.""" |
| evidence = [ |
| Evidence( |
| content="A", |
| relevance=0.9, |
| citation=Citation(source="pubmed", title="A", url="a", date="2023"), |
| ), |
| Evidence( |
| content="B", |
| relevance=0.1, |
| citation=Citation(source="pubmed", title="B", url="b", date="2023"), |
| ), |
| Evidence( |
| content="C", |
| relevance=0.8, |
| citation=Citation(source="pubmed", title="C", url="c", date="2023"), |
| ), |
| ] |
|
|
| selected = await select_diverse_evidence(evidence, n=2, query="test", embeddings=None) |
|
|
| expected_count = 2 |
| assert len(selected) == expected_count |
| assert selected[0].content == "A" |
| assert selected[1].content == "C" |
|
|
| @pytest.mark.asyncio |
| async def test_select_diverse_evidence_mmr(self): |
| """Should select diverse evidence using MMR.""" |
| |
| mock_embeddings = MagicMock() |
|
|
| |
| |
| |
| |
|
|
| |
| |
| |
| |
|
|
| async def mock_embed(text): |
| if text == "query": |
| return [0.707, 0.707] |
| return [0.0, 0.0] |
|
|
| async def mock_embed_batch(texts): |
| results = [] |
| for t in texts: |
| if t == "A": |
| results.append([1.0, 0.0]) |
| elif t == "B": |
| results.append([1.0, 0.0]) |
| elif t == "C": |
| results.append([0.0, 1.0]) |
| else: |
| results.append([0.0, 0.0]) |
| return results |
|
|
| mock_embeddings.embed = AsyncMock(side_effect=mock_embed) |
| mock_embeddings.embed_batch = AsyncMock(side_effect=mock_embed_batch) |
|
|
| evidence = [ |
| Evidence( |
| content="A", |
| relevance=0.9, |
| citation=Citation(source="pubmed", title="A", url="a", date="2023"), |
| ), |
| Evidence( |
| content="B", |
| relevance=0.9, |
| citation=Citation(source="pubmed", title="B", url="b", date="2023"), |
| ), |
| Evidence( |
| content="C", |
| relevance=0.9, |
| citation=Citation(source="pubmed", title="C", url="c", date="2023"), |
| ), |
| ] |
|
|
| |
| selected = await select_diverse_evidence( |
| evidence, n=2, query="query", embeddings=mock_embeddings |
| ) |
|
|
| expected_count = 2 |
| assert len(selected) == expected_count |
| assert selected[0].content == "A" |
| assert selected[1].content == "C" |
|
|