DeepBoner / src /tools /search_handler.py
VibecoderMcSwaggins's picture
feat(search): SPEC_13 Evidence Deduplication (#98)
2c5db87 unverified
raw
history blame
7.47 kB
"""Search handler - orchestrates multiple search tools."""
import asyncio
import re
from typing import TYPE_CHECKING, cast
import structlog
from src.tools.base import SearchTool
from src.utils.exceptions import SearchError
from src.utils.models import Evidence, SearchResult, SourceName
if TYPE_CHECKING:
    # NOTE(review): redundant — Evidence is already imported at runtime from
    # src.utils.models above; this guard only has an effect for type checkers.
    from src.utils.models import Evidence

# Module-level structlog logger used by the dedup helpers and SearchHandler.
logger = structlog.get_logger()
def extract_paper_id(evidence: "Evidence") -> str | None:
"""Extract unique paper identifier from Evidence.
Strategy:
1. Check metadata.pmid first (OpenAlex provides this)
2. Fall back to URL pattern matching
Supports:
- PubMed: https://pubmed.ncbi.nlm.nih.gov/12345678/
- Europe PMC MED: https://europepmc.org/article/MED/12345678
- Europe PMC PMC: https://europepmc.org/article/PMC/PMC1234567
- Europe PMC PPR: https://europepmc.org/article/PPR/PPR123456
- Europe PMC PAT: https://europepmc.org/article/PAT/WO8601415
- DOI: https://doi.org/10.1234/...
- OpenAlex: https://openalex.org/W1234567890
- ClinicalTrials: https://clinicaltrials.gov/study/NCT12345678
- ClinicalTrials (legacy): https://clinicaltrials.gov/ct2/show/NCT12345678
"""
url = evidence.citation.url
metadata = evidence.metadata or {}
# Strategy 1: Check metadata.pmid (from OpenAlex)
if pmid := metadata.get("pmid"):
return f"PMID:{pmid}"
# Strategy 2: URL pattern matching
# PubMed URL pattern
pmid_match = re.search(r"pubmed\.ncbi\.nlm\.nih\.gov/(\d+)", url)
if pmid_match:
return f"PMID:{pmid_match.group(1)}"
# Europe PMC MED pattern (same as PMID)
epmc_med_match = re.search(r"europepmc\.org/article/MED/(\d+)", url)
if epmc_med_match:
return f"PMID:{epmc_med_match.group(1)}"
# Europe PMC PMC pattern (PubMed Central ID - different from PMID!)
epmc_pmc_match = re.search(r"europepmc\.org/article/PMC/(PMC\d+)", url)
if epmc_pmc_match:
return f"PMCID:{epmc_pmc_match.group(1)}"
# Europe PMC PPR pattern (Preprint ID - unique per preprint)
epmc_ppr_match = re.search(r"europepmc\.org/article/PPR/(PPR\d+)", url)
if epmc_ppr_match:
return f"PPRID:{epmc_ppr_match.group(1)}"
# Europe PMC PAT pattern (Patent ID - e.g., WO8601415, EP1234567)
epmc_pat_match = re.search(r"europepmc\.org/article/PAT/([A-Z]{2}\d+)", url)
if epmc_pat_match:
return f"PATID:{epmc_pat_match.group(1)}"
# DOI pattern (normalize trailing slash/characters)
doi_match = re.search(r"doi\.org/(10\.\d+/[^\s\]>]+)", url)
if doi_match:
doi = doi_match.group(1).rstrip("/")
return f"DOI:{doi}"
# OpenAlex ID pattern (fallback if no PMID in metadata)
openalex_match = re.search(r"openalex\.org/(W\d+)", url)
if openalex_match:
return f"OAID:{openalex_match.group(1)}"
# ClinicalTrials NCT ID (modern format)
nct_match = re.search(r"clinicaltrials\.gov/study/(NCT\d+)", url)
if nct_match:
return f"NCT:{nct_match.group(1)}"
# ClinicalTrials NCT ID (legacy format)
nct_legacy_match = re.search(r"clinicaltrials\.gov/ct2/show/(NCT\d+)", url)
if nct_legacy_match:
return f"NCT:{nct_legacy_match.group(1)}"
return None
def deduplicate_evidence(evidence_list: list["Evidence"]) -> list["Evidence"]:
    """Drop evidence items that refer to a paper already kept.

    Items are first ordered by source rank: PubMed (authoritative), then
    Europe PMC (full-text links), then OpenAlex (citation data), then
    ClinicalTrials, then any other source. The sort is stable, so items
    from the same source keep their relative order. Walking that ordering,
    the first item seen for each paper ID is kept and later ones are
    dropped. Items whose paper ID cannot be determined are always kept
    (conservative).

    Args:
        evidence_list: Evidence gathered from one or more search tools.

    Returns:
        A new list in source-rank order with duplicate papers removed.
    """
    rank = {"pubmed": 0, "europepmc": 1, "openalex": 2, "clinicaltrials": 3}

    def by_source(item: "Evidence") -> int:
        # Unranked sources sort after every ranked one.
        return rank.get(item.citation.source, 99)

    kept: list[Evidence] = []
    already_seen: set[str] = set()
    for item in sorted(evidence_list, key=by_source):
        key = extract_paper_id(item)
        if key is not None:
            if key in already_seen:
                continue
            already_seen.add(key)
        kept.append(item)
    return kept
class SearchHandler:
    """Orchestrates parallel searches across multiple tools."""

    def __init__(self, tools: list[SearchTool], timeout: float = 30.0) -> None:
        """
        Initialize the search handler.

        Args:
            tools: List of search tools to use
            timeout: Timeout for each search in seconds
        """
        self.tools = tools
        self.timeout = timeout

    async def execute(self, query: str, max_results_per_tool: int = 10) -> SearchResult:
        """
        Execute search across all tools in parallel.

        A failing or timed-out tool does not abort the search: its failure is
        recorded in ``errors`` and the remaining tools' evidence is still
        returned. Evidence is deduplicated across tools before returning.

        Args:
            query: The search query
            max_results_per_tool: Max results from each tool

        Returns:
            SearchResult containing all (deduplicated) evidence, the tools that
            succeeded, and one error string per failed tool.
        """
        logger.info("Starting search", query=query, tools=[t.name for t in self.tools])

        # One coroutine per tool, run concurrently by gather().
        tasks = [
            self._search_with_timeout(tool, query, max_results_per_tool) for tool in self.tools
        ]

        # return_exceptions=True: a failing tool yields its exception object in
        # `results` instead of cancelling the whole gather.
        results = await asyncio.gather(*tasks, return_exceptions=True)

        all_evidence: list[Evidence] = []
        sources_searched: list[SourceName] = []
        errors: list[str] = []

        for tool, result in zip(self.tools, results, strict=True):
            # Must check BaseException, not Exception: gather() reports a
            # cancelled child task as asyncio.CancelledError, which is not an
            # Exception subclass and would otherwise reach extend() and crash.
            if isinstance(result, BaseException):
                # Some exceptions (e.g. CancelledError) have an empty str();
                # fall back to the type name so the error is still informative.
                detail = str(result) or type(result).__name__
                errors.append(f"{tool.name}: {detail}")
                logger.warning("Search tool failed", tool=tool.name, error=detail)
                continue

            # Cast result to list[Evidence] as we know it succeeded
            success_result = cast(list[Evidence], result)
            all_evidence.extend(success_result)
            # Cast tool.name to SourceName (centralized type from models)
            sources_searched.append(cast(SourceName, tool.name))
            logger.info("Search tool succeeded", tool=tool.name, count=len(success_result))

        # DEDUPLICATION STEP: the same paper is often returned by several tools.
        original_count = len(all_evidence)
        all_evidence = deduplicate_evidence(all_evidence)
        dedup_count = original_count - len(all_evidence)
        if dedup_count > 0:
            logger.info(
                "Deduplicated evidence",
                original=original_count,
                unique=len(all_evidence),
                removed=dedup_count,
            )

        return SearchResult(
            query=query,
            evidence=all_evidence,
            sources_searched=sources_searched,
            total_found=len(all_evidence),
            errors=errors,
        )

    async def _search_with_timeout(
        self,
        tool: SearchTool,
        query: str,
        max_results: int,
    ) -> list[Evidence]:
        """Execute a single tool search, bounded by ``self.timeout`` seconds.

        Raises:
            SearchError: If the tool does not finish within the timeout. Other
                exceptions raised by the tool propagate unchanged.
        """
        try:
            return await asyncio.wait_for(
                tool.search(query, max_results),
                timeout=self.timeout,
            )
        # NOTE(review): relies on Python >= 3.11, where asyncio.TimeoutError is
        # an alias of the builtin TimeoutError raised by wait_for().
        except TimeoutError as e:
            raise SearchError(f"{tool.name} search timed out after {self.timeout}s") from e