# DeepBoner / tests/unit/tools/test_search_handler.py
# Commit 2c5db87 (unverified) by VibecoderMcSwaggins:
#   feat(search): SPEC_13 Evidence Deduplication (#98)
"""Unit tests for SearchHandler."""
from unittest.mock import AsyncMock, create_autospec
import pytest
from src.tools.base import SearchTool
from src.tools.search_handler import SearchHandler, deduplicate_evidence, extract_paper_id
from src.utils.exceptions import SearchError
from src.utils.models import Citation, Evidence
def _make_evidence(source: str, url: str, metadata: dict | None = None) -> Evidence:
    """Build a minimal ``Evidence`` fixture for the given source and URL.

    Args:
        source: Stored as ``Citation.source`` (e.g. ``"pubmed"``).
        url: Stored as ``Citation.url``; the tests in this module vary it to
            exercise paper-ID extraction.
        metadata: Optional metadata mapping. When it is ``None`` or empty, a
            fresh empty dict is used instead.

    Returns:
        An ``Evidence`` with placeholder content, title, date and no authors.
    """
    citation = Citation(
        source=source,
        title="Test",
        url=url,
        date="2024",
        authors=[],
    )
    return Evidence(content="Test content", citation=citation, metadata=metadata or {})
class TestExtractPaperId:
    """Check the canonical ID strings ``extract_paper_id`` derives from Evidence."""

    def test_extracts_pubmed_id(self) -> None:
        """A PubMed article URL maps to a ``PMID:`` key."""
        url = "https://pubmed.ncbi.nlm.nih.gov/12345678/"
        result = extract_paper_id(_make_evidence("pubmed", url))
        assert result == "PMID:12345678"

    def test_extracts_europepmc_med_id(self) -> None:
        """A Europe PMC ``MED`` article maps to the same ``PMID:`` key as PubMed."""
        url = "https://europepmc.org/article/MED/12345678"
        result = extract_paper_id(_make_evidence("europepmc", url))
        assert result == "PMID:12345678"

    def test_extracts_europepmc_pmc_id(self) -> None:
        """A Europe PMC ``PMC`` article maps to a ``PMCID:`` key."""
        url = "https://europepmc.org/article/PMC/PMC7654321"
        result = extract_paper_id(_make_evidence("europepmc", url))
        assert result == "PMCID:PMC7654321"

    def test_extracts_europepmc_ppr_id(self) -> None:
        """A Europe PMC preprint (``PPR``) maps to a ``PPRID:`` key."""
        url = "https://europepmc.org/article/PPR/PPR123456"
        result = extract_paper_id(_make_evidence("europepmc", url))
        assert result == "PPRID:PPR123456"

    def test_extracts_europepmc_pat_id(self) -> None:
        """A Europe PMC patent with a WIPO-style number maps to a ``PATID:`` key."""
        url = "https://europepmc.org/article/PAT/WO8601415"
        result = extract_paper_id(_make_evidence("europepmc", url))
        assert result == "PATID:WO8601415"

    def test_extracts_europepmc_pat_id_eu_format(self) -> None:
        """A Europe PMC patent with a European (``EP``) number maps to a ``PATID:`` key."""
        url = "https://europepmc.org/article/PAT/EP1234567"
        result = extract_paper_id(_make_evidence("europepmc", url))
        assert result == "PATID:EP1234567"

    def test_extracts_doi(self) -> None:
        """A doi.org URL maps to a ``DOI:`` key."""
        url = "https://doi.org/10.1038/nature12345"
        result = extract_paper_id(_make_evidence("pubmed", url))
        assert result == "DOI:10.1038/nature12345"

    def test_extracts_doi_with_trailing_slash(self) -> None:
        """A trailing slash after the DOI is dropped from the key."""
        url = "https://doi.org/10.1038/nature12345/"
        result = extract_paper_id(_make_evidence("pubmed", url))
        assert result == "DOI:10.1038/nature12345"

    def test_extracts_openalex_id_from_url(self) -> None:
        """Without a metadata PMID, an OpenAlex work URL maps to an ``OAID:`` key."""
        url = "https://openalex.org/W1234567890"
        result = extract_paper_id(_make_evidence("openalex", url))
        assert result == "OAID:W1234567890"

    def test_extracts_openalex_pmid_from_metadata(self) -> None:
        """A ``pmid`` in OpenAlex metadata takes precedence over the work URL."""
        url = "https://openalex.org/W1234567890"
        evidence = _make_evidence("openalex", url, metadata={"pmid": "98765432"})
        result = extract_paper_id(evidence)
        assert result == "PMID:98765432"

    def test_extracts_nct_id_modern(self) -> None:
        """The current ``/study/`` ClinicalTrials.gov URL maps to an ``NCT:`` key."""
        url = "https://clinicaltrials.gov/study/NCT12345678"
        result = extract_paper_id(_make_evidence("clinicaltrials", url))
        assert result == "NCT:NCT12345678"

    def test_extracts_nct_id_legacy(self) -> None:
        """The legacy ``/ct2/show/`` ClinicalTrials.gov URL yields the same ``NCT:`` key."""
        url = "https://clinicaltrials.gov/ct2/show/NCT12345678"
        result = extract_paper_id(_make_evidence("clinicaltrials", url))
        assert result == "NCT:NCT12345678"

    def test_returns_none_for_unknown_url(self) -> None:
        """A URL matching no known pattern yields ``None``."""
        url = "https://example.com/unknown"
        result = extract_paper_id(_make_evidence("web", url))
        assert result is None
class TestDeduplicateEvidence:
    """Tests for evidence deduplication."""

    def test_removes_pubmed_europepmc_duplicate(self) -> None:
        """Same paper from PubMed and Europe PMC should dedupe to PubMed."""
        pubmed = _make_evidence("pubmed", "https://pubmed.ncbi.nlm.nih.gov/12345678/")
        europepmc = _make_evidence("europepmc", "https://europepmc.org/article/MED/12345678")
        # NOTE(review): PubMed is listed first here, so this test alone does not
        # distinguish "priority source wins" from "first occurrence wins".
        result = deduplicate_evidence([pubmed, europepmc])
        assert len(result) == 1
        assert result[0].citation.source == "pubmed"

    def test_removes_pubmed_openalex_duplicate_via_metadata(self) -> None:
        """OpenAlex with PMID in metadata should dedupe against PubMed."""
        pubmed = _make_evidence("pubmed", "https://pubmed.ncbi.nlm.nih.gov/12345678/")
        openalex = _make_evidence(
            "openalex",
            "https://openalex.org/W9999999",
            metadata={"pmid": "12345678", "cited_by_count": 100},
        )
        result = deduplicate_evidence([pubmed, openalex])
        assert len(result) == 1
        assert result[0].citation.source == "pubmed"

    def test_preserves_unique_evidence(self) -> None:
        """Different papers should not be deduplicated."""
        e1 = _make_evidence("pubmed", "https://pubmed.ncbi.nlm.nih.gov/11111111/")
        e2 = _make_evidence("pubmed", "https://pubmed.ncbi.nlm.nih.gov/22222222/")
        result = deduplicate_evidence([e1, e2])
        assert len(result) == 2

    def test_preserves_openalex_without_pmid(self) -> None:
        """OpenAlex papers without PMID should NOT be deduplicated against PubMed."""
        pubmed = _make_evidence("pubmed", "https://pubmed.ncbi.nlm.nih.gov/12345678/")
        openalex_no_pmid = _make_evidence(
            "openalex",
            "https://openalex.org/W9999999",
            metadata={"cited_by_count": 100},  # No pmid key
        )
        result = deduplicate_evidence([pubmed, openalex_no_pmid])
        assert len(result) == 2  # Both preserved (different IDs)

    def test_keeps_unidentifiable_evidence(self) -> None:
        """Evidence with unrecognized URLs should all be preserved.

        Two distinct unidentifiable items are used so that an implementation
        which collapses every ID-less item into a single bucket would fail;
        a single-item input cannot detect that.
        """
        unknown_a = _make_evidence("web", "https://example.com/paper/123")
        unknown_b = _make_evidence("web", "https://example.com/paper/456")
        result = deduplicate_evidence([unknown_a, unknown_b])
        assert len(result) == 2

    def test_empty_input_returns_empty(self) -> None:
        """Deduplicating nothing should yield nothing."""
        assert deduplicate_evidence([]) == []

    def test_clinicaltrials_unique_per_nct(self) -> None:
        """ClinicalTrials entries have unique NCT IDs."""
        trial1 = _make_evidence("clinicaltrials", "https://clinicaltrials.gov/study/NCT11111111")
        trial2 = _make_evidence("clinicaltrials", "https://clinicaltrials.gov/study/NCT22222222")
        result = deduplicate_evidence([trial1, trial2])
        assert len(result) == 2

    def test_preprints_preserved_separately(self) -> None:
        """Preprints (PPR IDs) should not dedupe against peer-reviewed papers."""
        peer_reviewed = _make_evidence("pubmed", "https://pubmed.ncbi.nlm.nih.gov/12345678/")
        preprint = _make_evidence("europepmc", "https://europepmc.org/article/PPR/PPR999999")
        result = deduplicate_evidence([peer_reviewed, preprint])
        assert len(result) == 2  # Both preserved (different ID types)
class TestSearchHandler:
    """Tests for SearchHandler."""

    @staticmethod
    def _mock_tool(name: str, search: AsyncMock) -> SearchTool:
        """Return an autospecced ``SearchTool`` stub with the given name and async ``search``.

        Both tests build their tools through this helper so every stub is
        constructed the same way (autospec plus an explicit ``AsyncMock``).
        """
        tool = create_autospec(SearchTool, instance=True)
        tool.name = name
        tool.search = search
        return tool

    @pytest.mark.asyncio
    async def test_execute_aggregates_and_deduplicates(self) -> None:
        """SearchHandler should aggregate results and deduplicate them."""
        pubmed_tool = self._mock_tool(
            "pubmed",
            AsyncMock(
                return_value=[
                    _make_evidence("pubmed", "https://pubmed.ncbi.nlm.nih.gov/12345678/")
                ]
            ),
        )
        # The same paper (PMID 12345678) surfaced through Europe PMC.
        europepmc_tool = self._mock_tool(
            "europepmc",
            AsyncMock(
                return_value=[
                    _make_evidence("europepmc", "https://europepmc.org/article/MED/12345678")
                ]
            ),
        )
        handler = SearchHandler(tools=[pubmed_tool, europepmc_tool])

        result = await handler.execute("test")

        # Only one result survives deduplication, and it is the PubMed copy.
        assert result.total_found == 1
        assert len(result.evidence) == 1
        assert result.evidence[0].citation.source == "pubmed"
        assert "pubmed" in result.sources_searched
        assert "europepmc" in result.sources_searched

    @pytest.mark.asyncio
    async def test_execute_handles_tool_failure(self) -> None:
        """SearchHandler should continue if one tool fails."""
        ok_tool = self._mock_tool(
            "pubmed",
            AsyncMock(
                return_value=[
                    _make_evidence("pubmed", "https://pubmed.ncbi.nlm.nih.gov/12345678/")
                ]
            ),
        )
        failing_tool = self._mock_tool(
            "clinicaltrials",
            AsyncMock(side_effect=SearchError("API down")),
        )
        handler = SearchHandler(tools=[ok_tool, failing_tool])

        result = await handler.execute("test")

        # The healthy tool's evidence is still returned...
        assert result.total_found == 1
        assert [e.citation.source for e in result.evidence] == ["pubmed"]
        assert "pubmed" in result.sources_searched
        # ...and the failure is reported rather than raised.
        assert len(result.errors) == 1
        assert "clinicaltrials: API down" in result.errors[0]