| """Unit tests for SearchHandler.""" |
|
|
| from unittest.mock import AsyncMock, create_autospec |
|
|
| import pytest |
|
|
| from src.tools.base import SearchTool |
| from src.tools.search_handler import SearchHandler, deduplicate_evidence, extract_paper_id |
| from src.utils.exceptions import SearchError |
| from src.utils.models import Citation, Evidence |
|
|
|
|
def _make_evidence(source: str, url: str, metadata: dict | None = None) -> Evidence:
    """Build a minimal Evidence fixture for the given source and URL.

    Content, title, date and authors are fixed placeholder values; only the
    citation source, the citation URL and the metadata vary between calls.
    A falsy ``metadata`` (``None`` or an empty dict) is replaced by a new
    empty dict.
    """
    citation = Citation(
        source=source,
        title="Test",
        url=url,
        date="2024",
        authors=[],
    )
    extra = metadata if metadata else {}
    return Evidence(content="Test content", citation=citation, metadata=extra)
|
|
|
|
class TestExtractPaperId:
    """Checks the identifier string that extract_paper_id derives from Evidence."""

    def test_extracts_pubmed_id(self) -> None:
        """A PubMed article URL maps to a ``PMID:``-prefixed identifier."""
        url = "https://pubmed.ncbi.nlm.nih.gov/12345678/"
        paper_id = extract_paper_id(_make_evidence("pubmed", url))
        assert paper_id == "PMID:12345678"

    def test_extracts_europepmc_med_id(self) -> None:
        """A Europe PMC ``MED`` article URL maps to the same ``PMID:`` form."""
        url = "https://europepmc.org/article/MED/12345678"
        paper_id = extract_paper_id(_make_evidence("europepmc", url))
        assert paper_id == "PMID:12345678"

    def test_extracts_europepmc_pmc_id(self) -> None:
        """A Europe PMC ``PMC`` article URL maps to a ``PMCID:`` identifier."""
        url = "https://europepmc.org/article/PMC/PMC7654321"
        paper_id = extract_paper_id(_make_evidence("europepmc", url))
        assert paper_id == "PMCID:PMC7654321"

    def test_extracts_europepmc_ppr_id(self) -> None:
        """A Europe PMC preprint (``PPR``) URL maps to a ``PPRID:`` identifier."""
        url = "https://europepmc.org/article/PPR/PPR123456"
        paper_id = extract_paper_id(_make_evidence("europepmc", url))
        assert paper_id == "PPRID:PPR123456"

    def test_extracts_europepmc_pat_id(self) -> None:
        """A Europe PMC patent (``PAT``) URL with a WIPO number maps to ``PATID:``."""
        url = "https://europepmc.org/article/PAT/WO8601415"
        paper_id = extract_paper_id(_make_evidence("europepmc", url))
        assert paper_id == "PATID:WO8601415"

    def test_extracts_europepmc_pat_id_eu_format(self) -> None:
        """A European-format patent number is extracted the same way."""
        url = "https://europepmc.org/article/PAT/EP1234567"
        paper_id = extract_paper_id(_make_evidence("europepmc", url))
        assert paper_id == "PATID:EP1234567"

    def test_extracts_doi(self) -> None:
        """A doi.org URL maps to a ``DOI:``-prefixed identifier."""
        url = "https://doi.org/10.1038/nature12345"
        paper_id = extract_paper_id(_make_evidence("pubmed", url))
        assert paper_id == "DOI:10.1038/nature12345"

    def test_extracts_doi_with_trailing_slash(self) -> None:
        """A trailing slash on a DOI URL is dropped from the identifier."""
        url = "https://doi.org/10.1038/nature12345/"
        paper_id = extract_paper_id(_make_evidence("pubmed", url))
        assert paper_id == "DOI:10.1038/nature12345"

    def test_extracts_openalex_id_from_url(self) -> None:
        """Without a PMID in metadata, the OpenAlex work ID comes from the URL."""
        url = "https://openalex.org/W1234567890"
        paper_id = extract_paper_id(_make_evidence("openalex", url))
        assert paper_id == "OAID:W1234567890"

    def test_extracts_openalex_pmid_from_metadata(self) -> None:
        """A PMID in OpenAlex metadata wins over the OpenAlex URL."""
        url = "https://openalex.org/W1234567890"
        paper = _make_evidence("openalex", url, metadata={"pmid": "98765432"})
        paper_id = extract_paper_id(paper)
        assert paper_id == "PMID:98765432"

    def test_extracts_nct_id_modern(self) -> None:
        """A ``/study/`` ClinicalTrials.gov URL maps to an ``NCT:`` identifier."""
        url = "https://clinicaltrials.gov/study/NCT12345678"
        paper_id = extract_paper_id(_make_evidence("clinicaltrials", url))
        assert paper_id == "NCT:NCT12345678"

    def test_extracts_nct_id_legacy(self) -> None:
        """The legacy ``/ct2/show/`` ClinicalTrials.gov URL is handled as well."""
        url = "https://clinicaltrials.gov/ct2/show/NCT12345678"
        paper_id = extract_paper_id(_make_evidence("clinicaltrials", url))
        assert paper_id == "NCT:NCT12345678"

    def test_returns_none_for_unknown_url(self) -> None:
        """An unrecognized URL produces no identifier."""
        url = "https://example.com/unknown"
        paper_id = extract_paper_id(_make_evidence("web", url))
        assert paper_id is None
|
|
|
|
class TestDeduplicateEvidence:
    """Checks which Evidence items deduplicate_evidence keeps or drops."""

    def test_removes_pubmed_europepmc_duplicate(self) -> None:
        """The same PMID from PubMed and Europe PMC collapses to the PubMed entry."""
        candidates = [
            _make_evidence("pubmed", "https://pubmed.ncbi.nlm.nih.gov/12345678/"),
            _make_evidence("europepmc", "https://europepmc.org/article/MED/12345678"),
        ]

        kept = deduplicate_evidence(candidates)

        assert len(kept) == 1
        assert kept[0].citation.source == "pubmed"

    def test_removes_pubmed_openalex_duplicate_via_metadata(self) -> None:
        """An OpenAlex entry whose metadata PMID matches PubMed is dropped."""
        from_pubmed = _make_evidence("pubmed", "https://pubmed.ncbi.nlm.nih.gov/12345678/")
        from_openalex = _make_evidence(
            "openalex",
            "https://openalex.org/W9999999",
            metadata={"pmid": "12345678", "cited_by_count": 100},
        )

        kept = deduplicate_evidence([from_pubmed, from_openalex])

        assert len(kept) == 1
        assert kept[0].citation.source == "pubmed"

    def test_preserves_unique_evidence(self) -> None:
        """Two PubMed entries with different PMIDs are both kept."""
        candidates = [
            _make_evidence("pubmed", "https://pubmed.ncbi.nlm.nih.gov/11111111/"),
            _make_evidence("pubmed", "https://pubmed.ncbi.nlm.nih.gov/22222222/"),
        ]

        kept = deduplicate_evidence(candidates)

        assert len(kept) == 2

    def test_preserves_openalex_without_pmid(self) -> None:
        """An OpenAlex entry lacking a PMID is not merged with a PubMed entry."""
        from_pubmed = _make_evidence("pubmed", "https://pubmed.ncbi.nlm.nih.gov/12345678/")
        from_openalex = _make_evidence(
            "openalex",
            "https://openalex.org/W9999999",
            metadata={"cited_by_count": 100},
        )

        kept = deduplicate_evidence([from_pubmed, from_openalex])

        assert len(kept) == 2

    def test_keeps_unidentifiable_evidence(self) -> None:
        """Evidence with an unrecognized URL is still returned."""
        candidates = [_make_evidence("web", "https://example.com/paper/123")]

        kept = deduplicate_evidence(candidates)

        assert len(kept) == 1

    def test_clinicaltrials_unique_per_nct(self) -> None:
        """Trials with different NCT IDs are both kept."""
        candidates = [
            _make_evidence("clinicaltrials", "https://clinicaltrials.gov/study/NCT11111111"),
            _make_evidence("clinicaltrials", "https://clinicaltrials.gov/study/NCT22222222"),
        ]

        kept = deduplicate_evidence(candidates)

        assert len(kept) == 2

    def test_preprints_preserved_separately(self) -> None:
        """A preprint (PPR ID) is kept alongside a PubMed paper."""
        candidates = [
            _make_evidence("pubmed", "https://pubmed.ncbi.nlm.nih.gov/12345678/"),
            _make_evidence("europepmc", "https://europepmc.org/article/PPR/PPR999999"),
        ]

        kept = deduplicate_evidence(candidates)

        assert len(kept) == 2
|
|
|
|
class TestSearchHandler:
    """Tests for SearchHandler.execute against mocked SearchTool instances.

    Both tests build their tools the same way: ``create_autospec`` for the
    tool object plus an ``AsyncMock`` standing in for its ``search`` coroutine.
    """

    @pytest.mark.asyncio
    async def test_execute_aggregates_and_deduplicates(self) -> None:
        """SearchHandler should aggregate results and deduplicate them."""
        # Arrange: two tools each return one result for the same article number
        # (12345678), one via a PubMed URL and one via a Europe PMC MED URL.
        pubmed_tool = create_autospec(SearchTool, instance=True)
        pubmed_tool.name = "pubmed"
        pubmed_tool.search = AsyncMock(
            return_value=[_make_evidence("pubmed", "https://pubmed.ncbi.nlm.nih.gov/12345678/")]
        )

        europepmc_tool = create_autospec(SearchTool, instance=True)
        europepmc_tool.name = "europepmc"
        europepmc_tool.search = AsyncMock(
            return_value=[_make_evidence("europepmc", "https://europepmc.org/article/MED/12345678")]
        )

        handler = SearchHandler(tools=[pubmed_tool, europepmc_tool])

        # Act
        result = await handler.execute("test")

        # Assert: a single PubMed-sourced item remains, while both tools are
        # still reported as searched.
        assert result.total_found == 1
        assert len(result.evidence) == 1
        assert result.evidence[0].citation.source == "pubmed"
        assert "pubmed" in result.sources_searched
        assert "europepmc" in result.sources_searched

    @pytest.mark.asyncio
    async def test_execute_handles_tool_failure(self) -> None:
        """SearchHandler should continue if one tool fails."""
        # Arrange: one healthy tool and one whose search raises SearchError.
        working_tool = create_autospec(SearchTool, instance=True)
        working_tool.name = "pubmed"
        working_tool.search = AsyncMock(
            return_value=[_make_evidence("pubmed", "https://pubmed.ncbi.nlm.nih.gov/12345678/")]
        )

        failing_tool = create_autospec(SearchTool, instance=True)
        failing_tool.name = "clinicaltrials"
        failing_tool.search = AsyncMock(side_effect=SearchError("API down"))

        handler = SearchHandler(tools=[working_tool, failing_tool])

        # Act
        result = await handler.execute("test")

        # Assert: the healthy tool's result is returned and the failure is
        # recorded as an error message naming the failing tool.
        assert result.total_found == 1
        assert "pubmed" in result.sources_searched
        assert len(result.errors) == 1
        assert "clinicaltrials: API down" in result.errors[0]
|
|