# NOTE: removed page-scrape artifacts (UI labels, blob hashes, and a
# line-number gutter) that were accidentally captured above the module
# docstring and are not part of this Python source file.
"""Search handler - orchestrates multiple search tools."""
import asyncio
import re
from typing import TYPE_CHECKING, cast
import structlog
from src.tools.base import SearchTool
from src.utils.exceptions import SearchError
from src.utils.models import Evidence, SearchResult, SourceName
if TYPE_CHECKING:
from src.utils.models import Evidence
logger = structlog.get_logger()
def extract_paper_id(evidence: "Evidence") -> str | None:
"""Extract unique paper identifier from Evidence.
Strategy:
1. Check metadata.pmid first (OpenAlex provides this)
2. Fall back to URL pattern matching
Supports:
- PubMed: https://pubmed.ncbi.nlm.nih.gov/12345678/
- Europe PMC MED: https://europepmc.org/article/MED/12345678
- Europe PMC PMC: https://europepmc.org/article/PMC/PMC1234567
- Europe PMC PPR: https://europepmc.org/article/PPR/PPR123456
- Europe PMC PAT: https://europepmc.org/article/PAT/WO8601415
- DOI: https://doi.org/10.1234/...
- OpenAlex: https://openalex.org/W1234567890
- ClinicalTrials: https://clinicaltrials.gov/study/NCT12345678
- ClinicalTrials (legacy): https://clinicaltrials.gov/ct2/show/NCT12345678
"""
url = evidence.citation.url
metadata = evidence.metadata or {}
# Strategy 1: Check metadata.pmid (from OpenAlex)
if pmid := metadata.get("pmid"):
return f"PMID:{pmid}"
# Strategy 2: URL pattern matching
# PubMed URL pattern
pmid_match = re.search(r"pubmed\.ncbi\.nlm\.nih\.gov/(\d+)", url)
if pmid_match:
return f"PMID:{pmid_match.group(1)}"
# Europe PMC MED pattern (same as PMID)
epmc_med_match = re.search(r"europepmc\.org/article/MED/(\d+)", url)
if epmc_med_match:
return f"PMID:{epmc_med_match.group(1)}"
# Europe PMC PMC pattern (PubMed Central ID - different from PMID!)
epmc_pmc_match = re.search(r"europepmc\.org/article/PMC/(PMC\d+)", url)
if epmc_pmc_match:
return f"PMCID:{epmc_pmc_match.group(1)}"
# Europe PMC PPR pattern (Preprint ID - unique per preprint)
epmc_ppr_match = re.search(r"europepmc\.org/article/PPR/(PPR\d+)", url)
if epmc_ppr_match:
return f"PPRID:{epmc_ppr_match.group(1)}"
# Europe PMC PAT pattern (Patent ID - e.g., WO8601415, EP1234567)
epmc_pat_match = re.search(r"europepmc\.org/article/PAT/([A-Z]{2}\d+)", url)
if epmc_pat_match:
return f"PATID:{epmc_pat_match.group(1)}"
# DOI pattern (normalize trailing slash/characters)
doi_match = re.search(r"doi\.org/(10\.\d+/[^\s\]>]+)", url)
if doi_match:
doi = doi_match.group(1).rstrip("/")
return f"DOI:{doi}"
# OpenAlex ID pattern (fallback if no PMID in metadata)
openalex_match = re.search(r"openalex\.org/(W\d+)", url)
if openalex_match:
return f"OAID:{openalex_match.group(1)}"
# ClinicalTrials NCT ID (modern format)
nct_match = re.search(r"clinicaltrials\.gov/study/(NCT\d+)", url)
if nct_match:
return f"NCT:{nct_match.group(1)}"
# ClinicalTrials NCT ID (legacy format)
nct_legacy_match = re.search(r"clinicaltrials\.gov/ct2/show/(NCT\d+)", url)
if nct_legacy_match:
return f"NCT:{nct_legacy_match.group(1)}"
return None
def deduplicate_evidence(evidence_list: list["Evidence"]) -> list["Evidence"]:
    """Remove duplicate evidence based on paper ID.

    Deduplication priority:
    1. PubMed (authoritative source)
    2. Europe PMC (full text links)
    3. OpenAlex (citation data)
    4. ClinicalTrials (unique, never duplicated)

    Returns:
        Deduplicated list preserving source priority order.
    """
    # Lower rank = higher trust; unrecognized sources sort last (rank 99).
    rank = {"pubmed": 0, "europepmc": 1, "openalex": 2, "clinicaltrials": 3}
    by_priority = sorted(
        evidence_list, key=lambda ev: rank.get(ev.citation.source, 99)
    )
    seen: set[str] = set()
    kept: list[Evidence] = []
    for ev in by_priority:
        identifier = extract_paper_id(ev)
        if identifier is not None:
            if identifier in seen:
                # Already kept via a higher-priority source — drop it.
                continue
            seen.add(identifier)
        # Unidentifiable evidence is kept conservatively.
        kept.append(ev)
    return kept
class SearchHandler:
    """Orchestrates parallel searches across multiple tools."""

    def __init__(self, tools: list[SearchTool], timeout: float = 30.0) -> None:
        """
        Initialize the search handler.

        Args:
            tools: List of search tools to use
            timeout: Timeout for each search in seconds
        """
        self.tools = tools
        self.timeout = timeout

    async def execute(self, query: str, max_results_per_tool: int = 10) -> SearchResult:
        """
        Execute search across all tools in parallel.

        Args:
            query: The search query
            max_results_per_tool: Max results from each tool

        Returns:
            SearchResult containing all evidence and metadata
        """
        logger.info("Starting search", query=query, tools=[t.name for t in self.tools])
        # Fan out one task per tool; return_exceptions keeps one failing
        # tool from sinking the whole search.
        outcomes = await asyncio.gather(
            *(
                self._search_with_timeout(tool, query, max_results_per_tool)
                for tool in self.tools
            ),
            return_exceptions=True,
        )
        all_evidence: list[Evidence] = []
        sources_searched: list[SourceName] = []
        errors: list[str] = []
        for tool, outcome in zip(self.tools, outcomes, strict=True):
            if isinstance(outcome, Exception):
                errors.append(f"{tool.name}: {outcome!s}")
                logger.warning("Search tool failed", tool=tool.name, error=str(outcome))
                continue
            # Success path: narrow the gathered value back to its real type.
            evidence_batch = cast(list[Evidence], outcome)
            all_evidence.extend(evidence_batch)
            # tool.name is narrowed to the centralized SourceName type.
            sources_searched.append(cast(SourceName, tool.name))
            logger.info("Search tool succeeded", tool=tool.name, count=len(evidence_batch))
        # Collapse duplicates reported by multiple sources.
        before = len(all_evidence)
        all_evidence = deduplicate_evidence(all_evidence)
        removed = before - len(all_evidence)
        if removed:
            logger.info(
                "Deduplicated evidence",
                original=before,
                unique=len(all_evidence),
                removed=removed,
            )
        return SearchResult(
            query=query,
            evidence=all_evidence,
            sources_searched=sources_searched,
            total_found=len(all_evidence),
            errors=errors,
        )

    async def _search_with_timeout(
        self,
        tool: SearchTool,
        query: str,
        max_results: int,
    ) -> list[Evidence]:
        """Execute a single tool search with timeout."""
        try:
            return await asyncio.wait_for(
                tool.search(query, max_results),
                timeout=self.timeout,
            )
        except TimeoutError as err:
            # Surface the timeout as a SearchError so execute() records it.
            raise SearchError(f"{tool.name} search timed out after {self.timeout}s") from err
|