| """Search handler - orchestrates multiple search tools.""" |
|
|
| import asyncio |
| import re |
| from typing import TYPE_CHECKING, cast |
|
|
| import structlog |
|
|
| from src.tools.base import SearchTool |
| from src.utils.exceptions import SearchError |
| from src.utils.models import Evidence, SearchResult, SourceName |
|
|
| if TYPE_CHECKING: |
| from src.utils.models import Evidence |
|
|
| logger = structlog.get_logger() |
|
|
|
|
| def extract_paper_id(evidence: "Evidence") -> str | None: |
| """Extract unique paper identifier from Evidence. |
| |
| Strategy: |
| 1. Check metadata.pmid first (OpenAlex provides this) |
| 2. Fall back to URL pattern matching |
| |
| Supports: |
| - PubMed: https://pubmed.ncbi.nlm.nih.gov/12345678/ |
| - Europe PMC MED: https://europepmc.org/article/MED/12345678 |
| - Europe PMC PMC: https://europepmc.org/article/PMC/PMC1234567 |
| - Europe PMC PPR: https://europepmc.org/article/PPR/PPR123456 |
| - Europe PMC PAT: https://europepmc.org/article/PAT/WO8601415 |
| - DOI: https://doi.org/10.1234/... |
| - OpenAlex: https://openalex.org/W1234567890 |
| - ClinicalTrials: https://clinicaltrials.gov/study/NCT12345678 |
| - ClinicalTrials (legacy): https://clinicaltrials.gov/ct2/show/NCT12345678 |
| """ |
| url = evidence.citation.url |
| metadata = evidence.metadata or {} |
|
|
| |
| if pmid := metadata.get("pmid"): |
| return f"PMID:{pmid}" |
|
|
| |
|
|
| |
| pmid_match = re.search(r"pubmed\.ncbi\.nlm\.nih\.gov/(\d+)", url) |
| if pmid_match: |
| return f"PMID:{pmid_match.group(1)}" |
|
|
| |
| epmc_med_match = re.search(r"europepmc\.org/article/MED/(\d+)", url) |
| if epmc_med_match: |
| return f"PMID:{epmc_med_match.group(1)}" |
|
|
| |
| epmc_pmc_match = re.search(r"europepmc\.org/article/PMC/(PMC\d+)", url) |
| if epmc_pmc_match: |
| return f"PMCID:{epmc_pmc_match.group(1)}" |
|
|
| |
| epmc_ppr_match = re.search(r"europepmc\.org/article/PPR/(PPR\d+)", url) |
| if epmc_ppr_match: |
| return f"PPRID:{epmc_ppr_match.group(1)}" |
|
|
| |
| epmc_pat_match = re.search(r"europepmc\.org/article/PAT/([A-Z]{2}\d+)", url) |
| if epmc_pat_match: |
| return f"PATID:{epmc_pat_match.group(1)}" |
|
|
| |
| doi_match = re.search(r"doi\.org/(10\.\d+/[^\s\]>]+)", url) |
| if doi_match: |
| doi = doi_match.group(1).rstrip("/") |
| return f"DOI:{doi}" |
|
|
| |
| openalex_match = re.search(r"openalex\.org/(W\d+)", url) |
| if openalex_match: |
| return f"OAID:{openalex_match.group(1)}" |
|
|
| |
| nct_match = re.search(r"clinicaltrials\.gov/study/(NCT\d+)", url) |
| if nct_match: |
| return f"NCT:{nct_match.group(1)}" |
|
|
| |
| nct_legacy_match = re.search(r"clinicaltrials\.gov/ct2/show/(NCT\d+)", url) |
| if nct_legacy_match: |
| return f"NCT:{nct_legacy_match.group(1)}" |
|
|
| return None |
|
|
|
|
def deduplicate_evidence(evidence_list: list["Evidence"]) -> list["Evidence"]:
    """Drop evidence items whose paper ID has already been seen.

    Deduplication priority (kept copy wins by source):
    1. PubMed (authoritative source)
    2. Europe PMC (full text links)
    3. OpenAlex (citation data)
    4. ClinicalTrials (unique, never duplicated)

    Evidence with no derivable paper ID is always kept.

    Returns:
        Deduplicated list ordered by source priority.
    """
    # Rank sources so the preferred copy of a duplicate is encountered first;
    # sorted() is stable, preserving relative order within each source.
    rank = {"pubmed": 0, "europepmc": 1, "openalex": 2, "clinicaltrials": 3}
    ordered = sorted(evidence_list, key=lambda ev: rank.get(ev.citation.source, 99))

    seen: set[str] = set()
    kept: list[Evidence] = []
    for item in ordered:
        identifier = extract_paper_id(item)
        if identifier is not None:
            if identifier in seen:
                continue  # lower-priority duplicate
            seen.add(identifier)
        # No identifier -> cannot prove it is a duplicate, so keep it.
        kept.append(item)

    return kept
|
|
|
|
class SearchHandler:
    """Fans a query out to all configured search tools concurrently."""

    def __init__(self, tools: list[SearchTool], timeout: float = 30.0) -> None:
        """
        Set up the handler.

        Args:
            tools: Search tools to query in parallel
            timeout: Per-tool timeout in seconds
        """
        self.tools = tools
        self.timeout = timeout

    async def execute(self, query: str, max_results_per_tool: int = 10) -> SearchResult:
        """
        Run the query against every tool in parallel and merge the results.

        Args:
            query: The search query
            max_results_per_tool: Result cap passed to each tool

        Returns:
            SearchResult with deduplicated evidence, the sources that
            responded successfully, and per-tool error messages
        """
        logger.info("Starting search", query=query, tools=[t.name for t in self.tools])

        # gather() keeps outcome order aligned with self.tools and, with
        # return_exceptions=True, hands back failures instead of raising.
        outcomes = await asyncio.gather(
            *(self._search_with_timeout(t, query, max_results_per_tool) for t in self.tools),
            return_exceptions=True,
        )

        collected: list[Evidence] = []
        succeeded: list[SourceName] = []
        failures: list[str] = []

        for tool, outcome in zip(self.tools, outcomes, strict=True):
            if isinstance(outcome, Exception):
                failures.append(f"{tool.name}: {outcome!s}")
                logger.warning("Search tool failed", tool=tool.name, error=str(outcome))
                continue

            # Non-exception outcomes are the tool's evidence list.
            evidence_items = cast(list[Evidence], outcome)
            collected.extend(evidence_items)
            succeeded.append(cast(SourceName, tool.name))
            logger.info("Search tool succeeded", tool=tool.name, count=len(evidence_items))

        before = len(collected)
        collected = deduplicate_evidence(collected)
        removed = before - len(collected)
        if removed > 0:
            logger.info(
                "Deduplicated evidence",
                original=before,
                unique=len(collected),
                removed=removed,
            )

        return SearchResult(
            query=query,
            evidence=collected,
            sources_searched=succeeded,
            total_found=len(collected),
            errors=failures,
        )

    async def _search_with_timeout(
        self,
        tool: SearchTool,
        query: str,
        max_results: int,
    ) -> list[Evidence]:
        """Run a single tool's search, converting a timeout into SearchError."""
        try:
            # NOTE(review): wait_for raises the builtin TimeoutError on
            # Python 3.11+; on 3.10 it raises asyncio.TimeoutError — this
            # catch assumes 3.11+, consistent with the file's modern syntax.
            return await asyncio.wait_for(
                tool.search(query, max_results),
                timeout=self.timeout,
            )
        except TimeoutError as e:
            raise SearchError(f"{tool.name} search timed out after {self.timeout}s") from e
|
|