"""Search handler - orchestrates multiple search tools.""" import asyncio import re from typing import TYPE_CHECKING, cast import structlog from src.tools.base import SearchTool from src.utils.exceptions import SearchError from src.utils.models import Evidence, SearchResult, SourceName if TYPE_CHECKING: from src.utils.models import Evidence logger = structlog.get_logger() def extract_paper_id(evidence: "Evidence") -> str | None: """Extract unique paper identifier from Evidence. Strategy: 1. Check metadata.pmid first (OpenAlex provides this) 2. Fall back to URL pattern matching Supports: - PubMed: https://pubmed.ncbi.nlm.nih.gov/12345678/ - Europe PMC MED: https://europepmc.org/article/MED/12345678 - Europe PMC PMC: https://europepmc.org/article/PMC/PMC1234567 - Europe PMC PPR: https://europepmc.org/article/PPR/PPR123456 - Europe PMC PAT: https://europepmc.org/article/PAT/WO8601415 - DOI: https://doi.org/10.1234/... - OpenAlex: https://openalex.org/W1234567890 - ClinicalTrials: https://clinicaltrials.gov/study/NCT12345678 - ClinicalTrials (legacy): https://clinicaltrials.gov/ct2/show/NCT12345678 """ url = evidence.citation.url metadata = evidence.metadata or {} # Strategy 1: Check metadata.pmid (from OpenAlex) if pmid := metadata.get("pmid"): return f"PMID:{pmid}" # Strategy 2: URL pattern matching # PubMed URL pattern pmid_match = re.search(r"pubmed\.ncbi\.nlm\.nih\.gov/(\d+)", url) if pmid_match: return f"PMID:{pmid_match.group(1)}" # Europe PMC MED pattern (same as PMID) epmc_med_match = re.search(r"europepmc\.org/article/MED/(\d+)", url) if epmc_med_match: return f"PMID:{epmc_med_match.group(1)}" # Europe PMC PMC pattern (PubMed Central ID - different from PMID!) epmc_pmc_match = re.search(r"europepmc\.org/article/PMC/(PMC\d+)", url) if epmc_pmc_match: return f"PMCID:{epmc_pmc_match.group(1)}" # Europe PMC PPR pattern (Preprint ID - unique per preprint) epmc_ppr_match = re.search(r"europepmc\.org/article/PPR/(PPR\d+)", url) if epmc_ppr_match: return f"PPRID:{epmc_ppr_match.group(1)}" # Europe PMC PAT pattern (Patent ID - e.g., WO8601415, EP1234567) epmc_pat_match = re.search(r"europepmc\.org/article/PAT/([A-Z]{2}\d+)", url) if epmc_pat_match: return f"PATID:{epmc_pat_match.group(1)}" # DOI pattern (normalize trailing slash/characters) doi_match = re.search(r"doi\.org/(10\.\d+/[^\s\]>]+)", url) if doi_match: doi = doi_match.group(1).rstrip("/") return f"DOI:{doi}" # OpenAlex ID pattern (fallback if no PMID in metadata) openalex_match = re.search(r"openalex\.org/(W\d+)", url) if openalex_match: return f"OAID:{openalex_match.group(1)}" # ClinicalTrials NCT ID (modern format) nct_match = re.search(r"clinicaltrials\.gov/study/(NCT\d+)", url) if nct_match: return f"NCT:{nct_match.group(1)}" # ClinicalTrials NCT ID (legacy format) nct_legacy_match = re.search(r"clinicaltrials\.gov/ct2/show/(NCT\d+)", url) if nct_legacy_match: return f"NCT:{nct_legacy_match.group(1)}" return None def deduplicate_evidence(evidence_list: list["Evidence"]) -> list["Evidence"]: """Remove duplicate evidence based on paper ID. Deduplication priority: 1. PubMed (authoritative source) 2. Europe PMC (full text links) 3. OpenAlex (citation data) 4. ClinicalTrials (unique, never duplicated) Returns: Deduplicated list preserving source priority order. """ seen_ids: set[str] = set() unique: list[Evidence] = [] # Sort by source priority (PubMed first) source_priority = {"pubmed": 0, "europepmc": 1, "openalex": 2, "clinicaltrials": 3} sorted_evidence = sorted( evidence_list, key=lambda e: source_priority.get(e.citation.source, 99) ) for evidence in sorted_evidence: paper_id = extract_paper_id(evidence) if paper_id is None: # Can't identify - keep it (conservative) unique.append(evidence) continue if paper_id not in seen_ids: seen_ids.add(paper_id) unique.append(evidence) return unique class SearchHandler: """Orchestrates parallel searches across multiple tools.""" def __init__(self, tools: list[SearchTool], timeout: float = 30.0) -> None: """ Initialize the search handler. Args: tools: List of search tools to use timeout: Timeout for each search in seconds """ self.tools = tools self.timeout = timeout async def execute(self, query: str, max_results_per_tool: int = 10) -> SearchResult: """ Execute search across all tools in parallel. Args: query: The search query max_results_per_tool: Max results from each tool Returns: SearchResult containing all evidence and metadata """ logger.info("Starting search", query=query, tools=[t.name for t in self.tools]) # Create tasks for parallel execution tasks = [ self._search_with_timeout(tool, query, max_results_per_tool) for tool in self.tools ] # Gather results (don't fail if one tool fails) results = await asyncio.gather(*tasks, return_exceptions=True) # Process results all_evidence: list[Evidence] = [] sources_searched: list[SourceName] = [] errors: list[str] = [] for tool, result in zip(self.tools, results, strict=True): if isinstance(result, Exception): errors.append(f"{tool.name}: {result!s}") logger.warning("Search tool failed", tool=tool.name, error=str(result)) else: # Cast result to list[Evidence] as we know it succeeded success_result = cast(list[Evidence], result) all_evidence.extend(success_result) # Cast tool.name to SourceName (centralized type from models) tool_name = cast(SourceName, tool.name) sources_searched.append(tool_name) logger.info("Search tool succeeded", tool=tool.name, count=len(success_result)) # DEDUPLICATION STEP original_count = len(all_evidence) all_evidence = deduplicate_evidence(all_evidence) dedup_count = original_count - len(all_evidence) if dedup_count > 0: logger.info( "Deduplicated evidence", original=original_count, unique=len(all_evidence), removed=dedup_count, ) return SearchResult( query=query, evidence=all_evidence, sources_searched=sources_searched, total_found=len(all_evidence), errors=errors, ) async def _search_with_timeout( self, tool: SearchTool, query: str, max_results: int, ) -> list[Evidence]: """Execute a single tool search with timeout.""" try: return await asyncio.wait_for( tool.search(query, max_results), timeout=self.timeout, ) except TimeoutError as e: raise SearchError(f"{tool.name} search timed out after {self.timeout}s") from e