explainer-env / research /router.py
kgdrathan's picture
Upload folder using huggingface_hub
b12f1bd verified
"""Explicit research tools used by the explore phase."""
from __future__ import annotations
import asyncio
import html
import re
import xml.etree.ElementTree as ET
from typing import Awaitable, Callable
import httpx
import wikipediaapi
try:
from .. import constants as _constants
except ImportError: # pragma: no cover - supports direct test execution
import constants as _constants
from .retrieval import MAX_RETURNED_CHUNKS, chunk_markdown, rank_chunks_for_query, trim_text
from .types import ResearchChunk, ResearchResult
# Re-exported from the shared constants module (presumably the list of tool
# names offered to the explore phase -- confirm against constants.py).
AVAILABLE_TOOLS = _constants.AVAILABLE_TOOLS
# Cap on upstream hits (search results, papers, documentation URLs) requested
# or processed per tool call before chunking and ranking.
MAX_SOURCE_RESULTS: int = 5
async def run_research_tool(tool: str, query: str, intent: str = "") -> ResearchResult:
    """Dispatch ``query`` to the research tool named ``tool``.

    Both ``tool`` and ``query`` are whitespace-stripped first. An unknown tool
    name or an empty query yields an error ``ResearchResult`` (the unknown-tool
    check wins when both apply); otherwise the matching runner's result is
    returned unchanged.
    """
    name = tool.strip()
    cleaned_query = query.strip()
    runner = _TOOL_RUNNERS.get(name)
    if runner is None:
        return ResearchResult(tool=name or "(missing)", query=cleaned_query, error="Unknown research tool")
    if cleaned_query:
        return await runner(cleaned_query, intent)
    return ResearchResult(tool=name, query=cleaned_query, error="Empty query")
async def search_wikipedia(query: str, intent: str = "") -> ResearchResult:
try:
wiki = wikipediaapi.AsyncWikipedia(
user_agent="ExplainerEnv/1.0 (hackathon project)",
language="en",
)
search_results = await wiki.search(query, limit=MAX_SOURCE_RESULTS)
titles = list(search_results.pages) if search_results and search_results.pages else []
chunks: list[ResearchChunk] = []
for title in titles:
page = wiki.page(title)
if not await page.exists():
continue
summary = await page.summary
sections = await page.sections
docs: list[tuple[str, str]] = []
if summary:
docs.append((title, summary))
docs.extend(_flatten_wikipedia_sections(sections))
for section_title, text in docs:
snippet = trim_text(text)
if snippet:
chunks.append(
ResearchChunk(
source="wikipedia",
tool="search_wikipedia",
title=f"{title} - {section_title}",
url=f"https://en.wikipedia.org/wiki/{title.replace(' ', '_')}",
text=snippet,
metadata={"page": title},
)
)
ranked = rank_chunks_for_query(query, intent, chunks, MAX_RETURNED_CHUNKS)
return ResearchResult("search_wikipedia", query, ranked, raw_count=len(chunks))
except Exception as exc:
return ResearchResult("search_wikipedia", query, error=str(exc))
async def _fetch_hf_paper_markdown(client: httpx.AsyncClient, paper_id: str) -> str:
    """Return the markdown rendering of a Hugging Face paper, or "" if unavailable.

    Failures here are non-fatal. The caller falls back to the abstract from
    the search response, so a slow or broken ``.md`` endpoint for one paper
    no longer discards every result gathered so far.
    """
    try:
        md_resp = await client.get(f"https://huggingface.co/papers/{paper_id}.md")
    except httpx.HTTPError:
        return ""
    if md_resp.status_code == 200 and md_resp.text.strip():
        return md_resp.text
    return ""


async def search_hf_papers(query: str, intent: str = "") -> ResearchResult:
    """Search Hugging Face Papers and return ranked abstract/full-text chunks.

    For each of up to ``MAX_SOURCE_RESULTS`` hits, the paper's markdown page
    is preferred (split into sections by ``chunk_markdown``). If it cannot be
    fetched, the abstract from the search response is used instead. Failures
    of the search request itself are reported through
    ``ResearchResult.error`` rather than raised.
    """
    try:
        async with httpx.AsyncClient(
            timeout=15.0,
            follow_redirects=True,
            headers={"User-Agent": "ExplainerEnv/1.0"},
        ) as client:
            resp = await client.get(
                "https://huggingface.co/api/papers/search",
                params={"q": query, "limit": MAX_SOURCE_RESULTS},
            )
            resp.raise_for_status()
            papers = resp.json() or []
            chunks: list[ResearchChunk] = []
            # Slice defensively in case the API ignores `limit`.
            for paper in papers[:MAX_SOURCE_RESULTS]:
                paper_id = paper.get("id") or ""
                title = paper.get("title") or "Untitled"
                url = f"https://huggingface.co/papers/{paper_id}" if paper_id else ""
                text_chunks = [("Abstract", paper.get("summary") or "")]
                if paper_id:
                    markdown = await _fetch_hf_paper_markdown(client, paper_id)
                    if markdown:
                        text_chunks = chunk_markdown(markdown, "Abstract")
                for section_title, text in text_chunks:
                    snippet = trim_text(text)
                    if snippet:
                        chunks.append(
                            ResearchChunk(
                                source="hf_papers",
                                tool="search_hf_papers",
                                title=f"{title} - {section_title}",
                                url=url,
                                text=snippet,
                                metadata={"paper_id": paper_id},
                            )
                        )
        ranked = rank_chunks_for_query(query, intent, chunks, MAX_RETURNED_CHUNKS)
        return ResearchResult("search_hf_papers", query, ranked, raw_count=len(chunks))
    except Exception as exc:
        return ResearchResult("search_hf_papers", query, error=str(exc))
async def search_arxiv(query: str, intent: str = "") -> ResearchResult:
try:
async with httpx.AsyncClient(timeout=15.0) as client:
resp = await client.get(
"https://export.arxiv.org/api/query",
params={
"search_query": f"all:{query}",
"start": 0,
"max_results": MAX_SOURCE_RESULTS,
"sortBy": "relevance",
"sortOrder": "descending",
},
headers={"User-Agent": "ExplainerEnv/1.0"},
)
resp.raise_for_status()
root = ET.fromstring(resp.text)
ns = {"atom": "http://www.w3.org/2005/Atom", "arxiv": "http://arxiv.org/schemas/atom"}
chunks: list[ResearchChunk] = []
for entry in root.findall("atom:entry", ns):
title = _xml_text(entry, "atom:title", ns) or "Untitled"
summary = _xml_text(entry, "atom:summary", ns)
url = _xml_text(entry, "atom:id", ns)
published = _xml_text(entry, "atom:published", ns)
authors = [
_xml_text(author, "atom:name", ns)
for author in entry.findall("atom:author", ns)
]
categories = [
cat.attrib.get("term", "")
for cat in entry.findall("atom:category", ns)
if cat.attrib.get("term")
]
snippet = trim_text(summary)
if snippet:
chunks.append(
ResearchChunk(
source="arxiv",
tool="search_arxiv",
title=html.unescape(title),
url=url,
text=html.unescape(snippet),
metadata={
"published": published,
"authors": [a for a in authors if a],
"categories": categories,
},
)
)
ranked = rank_chunks_for_query(query, intent, chunks, MAX_RETURNED_CHUNKS)
return ResearchResult("search_arxiv", query, ranked, raw_count=len(chunks))
except Exception as exc:
return ResearchResult("search_arxiv", query, error=str(exc))
async def search_scholar(query: str, intent: str = "") -> ResearchResult:
try:
async with httpx.AsyncClient(timeout=15.0) as client:
resp = await client.get(
"https://api.semanticscholar.org/graph/v1/paper/search",
params={
"query": query,
"limit": MAX_SOURCE_RESULTS,
"fields": "title,abstract,url,year,citationCount,venue,authors",
},
headers={"User-Agent": "ExplainerEnv/1.0"},
)
resp.raise_for_status()
papers = resp.json().get("data", [])
chunks: list[ResearchChunk] = []
for paper in papers:
abstract = paper.get("abstract") or ""
snippet = trim_text(abstract)
if not snippet:
continue
chunks.append(
ResearchChunk(
source="semantic_scholar",
tool="search_scholar",
title=paper.get("title", "Untitled"),
url=paper.get("url", ""),
text=snippet,
metadata={
"year": paper.get("year"),
"venue": paper.get("venue"),
"citation_count": paper.get("citationCount", 0),
},
)
)
ranked = rank_chunks_for_query(query, intent, chunks, MAX_RETURNED_CHUNKS)
return ResearchResult("search_scholar", query, ranked, raw_count=len(chunks))
except Exception as exc:
return ResearchResult("search_scholar", query, error=str(exc))
async def fetch_docs(query: str, intent: str = "") -> ResearchResult:
    """Fetch allow-listed documentation pages matching ``query`` and rank their sections.

    Only URLs chosen by ``_select_doc_urls`` (i.e. from ``_DOC_URLS``) are
    requested, at most ``MAX_SOURCE_RESULTS`` of them. A single page failure
    is skipped. If every page fails, the result carries an error naming the
    last failure instead of an empty "success". Errors while chunking or
    ranking are also reported through ``error``, as in the other tools here.
    """
    urls = _select_doc_urls(query)
    if not urls:
        return ResearchResult("fetch_docs", query, error="No allowed documentation target matched query")
    chunks: list[ResearchChunk] = []
    fetched = 0
    last_error = ""
    try:
        async with httpx.AsyncClient(timeout=15.0, follow_redirects=True) as client:
            for title, url in urls[:MAX_SOURCE_RESULTS]:
                try:
                    resp = await client.get(url, headers={"User-Agent": "ExplainerEnv/1.0"})
                    resp.raise_for_status()
                    text = _extract_page_text(resp.text)
                except Exception as exc:
                    # Best-effort: one unreachable page must not sink the others.
                    last_error = f"{url}: {exc}"
                    continue
                fetched += 1
                for section_title, body in chunk_markdown(text, title):
                    snippet = trim_text(body)
                    if snippet:
                        chunks.append(
                            ResearchChunk(
                                source="docs",
                                tool="fetch_docs",
                                title=f"{title} - {section_title}",
                                url=url,
                                text=snippet,
                                metadata={"doc": title},
                            )
                        )
        if not fetched:
            return ResearchResult(
                "fetch_docs", query, error=f"All documentation fetches failed ({last_error})"
            )
        ranked = rank_chunks_for_query(query, intent, chunks, MAX_RETURNED_CHUNKS)
        return ResearchResult("fetch_docs", query, ranked, raw_count=len(chunks))
    except Exception as exc:
        return ResearchResult("fetch_docs", query, error=str(exc))
async def search_hf_hub(query: str, intent: str = "") -> ResearchResult:
try:
from huggingface_hub import HfApi
except ImportError as exc:
return ResearchResult("search_hf_hub", query, error=f"huggingface_hub unavailable: {exc}")
def _load() -> list[ResearchChunk]:
api = HfApi()
chunks: list[ResearchChunk] = []
for model in api.list_models(search=query, limit=2):
text = " ".join(str(x) for x in [model.modelId, model.pipeline_tag, model.tags] if x)
chunks.append(
ResearchChunk(
source="hf_hub_model",
tool="search_hf_hub",
title=model.modelId,
url=f"https://huggingface.co/{model.modelId}",
text=trim_text(text),
metadata={"downloads": getattr(model, "downloads", None)},
)
)
for dataset in api.list_datasets(search=query, limit=2):
dataset_id = getattr(dataset, "id", "")
text = " ".join(str(x) for x in [dataset_id, getattr(dataset, "tags", None)] if x)
chunks.append(
ResearchChunk(
source="hf_hub_dataset",
tool="search_hf_hub",
title=dataset_id,
url=f"https://huggingface.co/datasets/{dataset_id}",
text=trim_text(text),
metadata={"downloads": getattr(dataset, "downloads", None)},
)
)
return chunks
try:
chunks = await asyncio.to_thread(_load)
ranked = rank_chunks_for_query(query, intent, chunks, MAX_RETURNED_CHUNKS)
return ResearchResult("search_hf_hub", query, ranked, raw_count=len(chunks))
except Exception as exc:
return ResearchResult("search_hf_hub", query, error=str(exc))
# Registry of public tool name -> coroutine implementation. run_research_tool
# dispatches through this mapping and rejects any name not listed here.
_TOOL_RUNNERS: dict[str, Callable[[str, str], Awaitable[ResearchResult]]] = dict(
    search_wikipedia=search_wikipedia,
    search_hf_papers=search_hf_papers,
    search_arxiv=search_arxiv,
    search_scholar=search_scholar,
    fetch_docs=fetch_docs,
    search_hf_hub=search_hf_hub,
)
_SKIP_SECTIONS = frozenset({
"references",
"external links",
"see also",
"further reading",
"notes",
"citations",
"bibliography",
"sources",
})
_DOC_URLS = {
"marimo": [
("marimo CLI", "https://docs.marimo.io/cli/"),
("marimo lint rules", "https://docs.marimo.io/guides/lint_rules/"),
("marimo duplicate definitions", "https://docs.marimo.io/guides/understanding_errors/multiple_definitions/"),
("marimo plotting", "https://docs.marimo.io/guides/working_with_data/plotting/"),
],
"manim": [
("Manim quickstart", "https://docs.manim.community/en/stable/tutorials/quickstart.html"),
("Manim reference", "https://docs.manim.community/en/stable/reference.html"),
],
"numpy": [("NumPy user guide", "https://numpy.org/doc/stable/user/")],
"scipy": [("SciPy user guide", "https://docs.scipy.org/doc/scipy/tutorial/")],
"sklearn": [("scikit-learn user guide", "https://scikit-learn.org/stable/user_guide.html")],
"pandas": [("pandas user guide", "https://pandas.pydata.org/docs/user_guide/")],
"pytorch": [("PyTorch docs", "https://pytorch.org/docs/stable/index.html")],
"plotly": [("Plotly Python docs", "https://plotly.com/python/")],
}
def _flatten_wikipedia_sections(sections, max_depth: int = 2, depth: int = 0) -> list[tuple[str, str]]:
    """Flatten a Wikipedia section tree into ``(title, text)`` pairs in pre-order.

    A section whose lower-cased title is in ``_SKIP_SECTIONS`` is dropped
    together with all of its subsections. Sections with blank text contribute
    no pair but their children are still visited. Recursion stops once
    ``depth`` reaches ``max_depth``; with the defaults, nesting levels 0-2 are
    included.
    """
    flattened: list[tuple[str, str]] = []
    for sec in sections:
        heading = sec.title
        if heading.lower() in _SKIP_SECTIONS:
            continue
        body = sec.text.strip()
        if body:
            flattened.append((heading, body))
        if depth < max_depth and sec.sections:
            flattened += _flatten_wikipedia_sections(sec.sections, max_depth, depth + 1)
    return flattened
def _xml_text(node, path: str, ns: dict[str, str]) -> str:
found = node.find(path, ns)
return found.text.strip() if found is not None and found.text else ""
# Extra query spellings that select a _DOC_URLS entry besides its own key.
_DOC_KEY_ALIASES: dict[str, tuple[str, ...]] = {
    "sklearn": ("scikit-learn", "scikit learn"),
    "pytorch": ("torch",),
}
# Generic notebook/visualization words that fall back to the marimo, manim and
# plotly docs when no library is named explicitly.
_DOC_FALLBACK_TERMS = ("notebook", "animation", "plot", "chart", "lint")


def _select_doc_urls(query: str) -> list[tuple[str, str]]:
    """Pick the allow-listed ``(title, url)`` doc pages relevant to ``query``.

    A ``_DOC_URLS`` entry is selected when its key, or one of its aliases in
    ``_DOC_KEY_ALIASES``, occurs as a case-insensitive substring of the query.
    Selected entries keep ``_DOC_URLS`` order. If nothing is selected but the
    query mentions a fallback term, the marimo, manim and plotly pages are
    returned instead. An empty list means no documentation target applies.
    """
    lower = query.lower()
    selected: list[tuple[str, str]] = []
    for key, urls in _DOC_URLS.items():
        spellings = (key, *_DOC_KEY_ALIASES.get(key, ()))
        if any(spelling in lower for spelling in spellings):
            selected.extend(urls)
    if not selected and any(term in lower for term in _DOC_FALLBACK_TERMS):
        for key in ("marimo", "manim", "plotly"):
            selected.extend(_DOC_URLS[key])
    return selected
def _extract_page_text(markup: str) -> str:
try:
import trafilatura
extracted = trafilatura.extract(markup, output_format="markdown")
if extracted:
return extracted
except Exception:
pass
text = re.sub(r"<(script|style).*?</\1>", " ", markup, flags=re.DOTALL | re.IGNORECASE)
text = re.sub(r"<[^>]+>", " ", text)
return html.unescape(re.sub(r"\s+", " ", text))