"""Explicit research tools used by the explore phase.""" from __future__ import annotations import asyncio import html import re import xml.etree.ElementTree as ET from typing import Awaitable, Callable import httpx import wikipediaapi try: from .. import constants as _constants except ImportError: # pragma: no cover - supports direct test execution import constants as _constants from .retrieval import MAX_RETURNED_CHUNKS, chunk_markdown, rank_chunks_for_query, trim_text from .types import ResearchChunk, ResearchResult AVAILABLE_TOOLS = _constants.AVAILABLE_TOOLS MAX_SOURCE_RESULTS = 5 async def run_research_tool(tool: str, query: str, intent: str = "") -> ResearchResult: """Run a named research tool and return structured chunks.""" tool = tool.strip() query = query.strip() if tool not in _TOOL_RUNNERS: return ResearchResult(tool=tool or "(missing)", query=query, error="Unknown research tool") if not query: return ResearchResult(tool=tool, query=query, error="Empty query") return await _TOOL_RUNNERS[tool](query, intent) async def search_wikipedia(query: str, intent: str = "") -> ResearchResult: try: wiki = wikipediaapi.AsyncWikipedia( user_agent="ExplainerEnv/1.0 (hackathon project)", language="en", ) search_results = await wiki.search(query, limit=MAX_SOURCE_RESULTS) titles = list(search_results.pages) if search_results and search_results.pages else [] chunks: list[ResearchChunk] = [] for title in titles: page = wiki.page(title) if not await page.exists(): continue summary = await page.summary sections = await page.sections docs: list[tuple[str, str]] = [] if summary: docs.append((title, summary)) docs.extend(_flatten_wikipedia_sections(sections)) for section_title, text in docs: snippet = trim_text(text) if snippet: chunks.append( ResearchChunk( source="wikipedia", tool="search_wikipedia", title=f"{title} - {section_title}", url=f"https://en.wikipedia.org/wiki/{title.replace(' ', '_')}", text=snippet, metadata={"page": title}, ) ) ranked = rank_chunks_for_query(query, intent, chunks, MAX_RETURNED_CHUNKS) return ResearchResult("search_wikipedia", query, ranked, raw_count=len(chunks)) except Exception as exc: return ResearchResult("search_wikipedia", query, error=str(exc)) async def search_hf_papers(query: str, intent: str = "") -> ResearchResult: try: async with httpx.AsyncClient(timeout=15.0, follow_redirects=True) as client: resp = await client.get( "https://huggingface.co/api/papers/search", params={"q": query, "limit": MAX_SOURCE_RESULTS}, headers={"User-Agent": "ExplainerEnv/1.0"}, ) resp.raise_for_status() papers = resp.json() or [] chunks: list[ResearchChunk] = [] for paper in papers[:MAX_SOURCE_RESULTS]: paper_id = paper.get("id", "") title = paper.get("title", "Untitled") url = f"https://huggingface.co/papers/{paper_id}" if paper_id else "" text_chunks = [("Abstract", paper.get("summary", ""))] if paper_id: md_resp = await client.get( f"https://huggingface.co/papers/{paper_id}.md", headers={"User-Agent": "ExplainerEnv/1.0"}, ) if md_resp.status_code == 200 and md_resp.text.strip(): text_chunks = chunk_markdown(md_resp.text, "Abstract") for section_title, text in text_chunks: snippet = trim_text(text) if snippet: chunks.append( ResearchChunk( source="hf_papers", tool="search_hf_papers", title=f"{title} - {section_title}", url=url, text=snippet, metadata={"paper_id": paper_id}, ) ) ranked = rank_chunks_for_query(query, intent, chunks, MAX_RETURNED_CHUNKS) return ResearchResult("search_hf_papers", query, ranked, raw_count=len(chunks)) except Exception as exc: return ResearchResult("search_hf_papers", query, error=str(exc)) async def search_arxiv(query: str, intent: str = "") -> ResearchResult: try: async with httpx.AsyncClient(timeout=15.0) as client: resp = await client.get( "https://export.arxiv.org/api/query", params={ "search_query": f"all:{query}", "start": 0, "max_results": MAX_SOURCE_RESULTS, "sortBy": "relevance", "sortOrder": "descending", }, headers={"User-Agent": "ExplainerEnv/1.0"}, ) resp.raise_for_status() root = ET.fromstring(resp.text) ns = {"atom": "http://www.w3.org/2005/Atom", "arxiv": "http://arxiv.org/schemas/atom"} chunks: list[ResearchChunk] = [] for entry in root.findall("atom:entry", ns): title = _xml_text(entry, "atom:title", ns) or "Untitled" summary = _xml_text(entry, "atom:summary", ns) url = _xml_text(entry, "atom:id", ns) published = _xml_text(entry, "atom:published", ns) authors = [ _xml_text(author, "atom:name", ns) for author in entry.findall("atom:author", ns) ] categories = [ cat.attrib.get("term", "") for cat in entry.findall("atom:category", ns) if cat.attrib.get("term") ] snippet = trim_text(summary) if snippet: chunks.append( ResearchChunk( source="arxiv", tool="search_arxiv", title=html.unescape(title), url=url, text=html.unescape(snippet), metadata={ "published": published, "authors": [a for a in authors if a], "categories": categories, }, ) ) ranked = rank_chunks_for_query(query, intent, chunks, MAX_RETURNED_CHUNKS) return ResearchResult("search_arxiv", query, ranked, raw_count=len(chunks)) except Exception as exc: return ResearchResult("search_arxiv", query, error=str(exc)) async def search_scholar(query: str, intent: str = "") -> ResearchResult: try: async with httpx.AsyncClient(timeout=15.0) as client: resp = await client.get( "https://api.semanticscholar.org/graph/v1/paper/search", params={ "query": query, "limit": MAX_SOURCE_RESULTS, "fields": "title,abstract,url,year,citationCount,venue,authors", }, headers={"User-Agent": "ExplainerEnv/1.0"}, ) resp.raise_for_status() papers = resp.json().get("data", []) chunks: list[ResearchChunk] = [] for paper in papers: abstract = paper.get("abstract") or "" snippet = trim_text(abstract) if not snippet: continue chunks.append( ResearchChunk( source="semantic_scholar", tool="search_scholar", title=paper.get("title", "Untitled"), url=paper.get("url", ""), text=snippet, metadata={ "year": paper.get("year"), "venue": paper.get("venue"), "citation_count": paper.get("citationCount", 0), }, ) ) ranked = rank_chunks_for_query(query, intent, chunks, MAX_RETURNED_CHUNKS) return ResearchResult("search_scholar", query, ranked, raw_count=len(chunks)) except Exception as exc: return ResearchResult("search_scholar", query, error=str(exc)) async def fetch_docs(query: str, intent: str = "") -> ResearchResult: urls = _select_doc_urls(query) if not urls: return ResearchResult("fetch_docs", query, error="No allowed documentation target matched query") chunks: list[ResearchChunk] = [] async with httpx.AsyncClient(timeout=15.0, follow_redirects=True) as client: for title, url in urls[:MAX_SOURCE_RESULTS]: try: resp = await client.get(url, headers={"User-Agent": "ExplainerEnv/1.0"}) resp.raise_for_status() text = _extract_page_text(resp.text) except Exception: continue for section_title, body in chunk_markdown(text, title): snippet = trim_text(body) if snippet: chunks.append( ResearchChunk( source="docs", tool="fetch_docs", title=f"{title} - {section_title}", url=url, text=snippet, metadata={"doc": title}, ) ) ranked = rank_chunks_for_query(query, intent, chunks, MAX_RETURNED_CHUNKS) return ResearchResult("fetch_docs", query, ranked, raw_count=len(chunks)) async def search_hf_hub(query: str, intent: str = "") -> ResearchResult: try: from huggingface_hub import HfApi except ImportError as exc: return ResearchResult("search_hf_hub", query, error=f"huggingface_hub unavailable: {exc}") def _load() -> list[ResearchChunk]: api = HfApi() chunks: list[ResearchChunk] = [] for model in api.list_models(search=query, limit=2): text = " ".join(str(x) for x in [model.modelId, model.pipeline_tag, model.tags] if x) chunks.append( ResearchChunk( source="hf_hub_model", tool="search_hf_hub", title=model.modelId, url=f"https://huggingface.co/{model.modelId}", text=trim_text(text), metadata={"downloads": getattr(model, "downloads", None)}, ) ) for dataset in api.list_datasets(search=query, limit=2): dataset_id = getattr(dataset, "id", "") text = " ".join(str(x) for x in [dataset_id, getattr(dataset, "tags", None)] if x) chunks.append( ResearchChunk( source="hf_hub_dataset", tool="search_hf_hub", title=dataset_id, url=f"https://huggingface.co/datasets/{dataset_id}", text=trim_text(text), metadata={"downloads": getattr(dataset, "downloads", None)}, ) ) return chunks try: chunks = await asyncio.to_thread(_load) ranked = rank_chunks_for_query(query, intent, chunks, MAX_RETURNED_CHUNKS) return ResearchResult("search_hf_hub", query, ranked, raw_count=len(chunks)) except Exception as exc: return ResearchResult("search_hf_hub", query, error=str(exc)) _TOOL_RUNNERS: dict[str, Callable[[str, str], Awaitable[ResearchResult]]] = { "search_wikipedia": search_wikipedia, "search_hf_papers": search_hf_papers, "search_arxiv": search_arxiv, "search_scholar": search_scholar, "fetch_docs": fetch_docs, "search_hf_hub": search_hf_hub, } _SKIP_SECTIONS = frozenset({ "references", "external links", "see also", "further reading", "notes", "citations", "bibliography", "sources", }) _DOC_URLS = { "marimo": [ ("marimo CLI", "https://docs.marimo.io/cli/"), ("marimo lint rules", "https://docs.marimo.io/guides/lint_rules/"), ("marimo duplicate definitions", "https://docs.marimo.io/guides/understanding_errors/multiple_definitions/"), ("marimo plotting", "https://docs.marimo.io/guides/working_with_data/plotting/"), ], "manim": [ ("Manim quickstart", "https://docs.manim.community/en/stable/tutorials/quickstart.html"), ("Manim reference", "https://docs.manim.community/en/stable/reference.html"), ], "numpy": [("NumPy user guide", "https://numpy.org/doc/stable/user/")], "scipy": [("SciPy user guide", "https://docs.scipy.org/doc/scipy/tutorial/")], "sklearn": [("scikit-learn user guide", "https://scikit-learn.org/stable/user_guide.html")], "pandas": [("pandas user guide", "https://pandas.pydata.org/docs/user_guide/")], "pytorch": [("PyTorch docs", "https://pytorch.org/docs/stable/index.html")], "plotly": [("Plotly Python docs", "https://plotly.com/python/")], } def _flatten_wikipedia_sections(sections, max_depth: int = 2, depth: int = 0) -> list[tuple[str, str]]: result: list[tuple[str, str]] = [] for section in sections: if section.title.lower() in _SKIP_SECTIONS: continue if section.text.strip(): result.append((section.title, section.text.strip())) if depth < max_depth and section.sections: result.extend(_flatten_wikipedia_sections(section.sections, max_depth, depth + 1)) return result def _xml_text(node, path: str, ns: dict[str, str]) -> str: found = node.find(path, ns) return found.text.strip() if found is not None and found.text else "" def _select_doc_urls(query: str) -> list[tuple[str, str]]: lower = query.lower() selected: list[tuple[str, str]] = [] for key, urls in _DOC_URLS.items(): if key in lower or key.replace("sklearn", "scikit-learn") in lower: selected.extend(urls) if not selected and any(term in lower for term in ("notebook", "animation", "plot", "chart", "lint")): selected.extend(_DOC_URLS["marimo"]) selected.extend(_DOC_URLS["manim"]) selected.extend(_DOC_URLS["plotly"]) return selected def _extract_page_text(markup: str) -> str: try: import trafilatura extracted = trafilatura.extract(markup, output_format="markdown") if extracted: return extracted except Exception: pass text = re.sub(r"<(script|style).*?", " ", markup, flags=re.DOTALL | re.IGNORECASE) text = re.sub(r"<[^>]+>", " ", text) return html.unescape(re.sub(r"\s+", " ", text))