Spaces:
Sleeping
Sleeping
| """Explicit research tools used by the explore phase.""" | |
| from __future__ import annotations | |
| import asyncio | |
| import html | |
| import re | |
| import xml.etree.ElementTree as ET | |
| from typing import Awaitable, Callable | |
| import httpx | |
| import wikipediaapi | |
| try: | |
| from .. import constants as _constants | |
| except ImportError: # pragma: no cover - supports direct test execution | |
| import constants as _constants | |
| from .retrieval import MAX_RETURNED_CHUNKS, chunk_markdown, rank_chunks_for_query, trim_text | |
| from .types import ResearchChunk, ResearchResult | |
# Upper bound on upstream hits requested or processed per source by the
# search/fetch tools in this module.
MAX_SOURCE_RESULTS = 5
# Tool names re-exported from the shared constants module.
AVAILABLE_TOOLS = _constants.AVAILABLE_TOOLS
async def run_research_tool(tool: str, query: str, intent: str = "") -> ResearchResult:
    """Dispatch *query* to the research tool registered under *tool*.

    Both *tool* and *query* are whitespace-stripped before use. An unknown
    tool name or an empty query is reported through the returned
    ``ResearchResult.error`` instead of being dispatched; otherwise the
    runner's own result is returned unchanged.
    """
    name = tool.strip()
    cleaned_query = query.strip()
    runner = _TOOL_RUNNERS.get(name)
    if runner is None:
        return ResearchResult(
            tool=name or "(missing)",
            query=cleaned_query,
            error="Unknown research tool",
        )
    if not cleaned_query:
        return ResearchResult(tool=name, query=cleaned_query, error="Empty query")
    return await runner(cleaned_query, intent)
async def search_wikipedia(query: str, intent: str = "") -> ResearchResult:
    """Search English Wikipedia and return query-ranked section chunks.

    Up to ``MAX_SOURCE_RESULTS`` search hits are visited. Each existing page
    contributes its summary plus the text of its sections (boilerplate
    headings such as "References" are filtered by
    ``_flatten_wikipedia_sections``), every piece passed through
    ``trim_text``. The chunks are ranked with ``rank_chunks_for_query``.

    Returns:
        A ``ResearchResult`` with the ranked chunks and ``raw_count`` set to
        the number of chunks before ranking, or with ``error`` set to the
        exception message if anything fails.
    """
    from urllib.parse import quote

    try:
        wiki = wikipediaapi.AsyncWikipedia(
            user_agent="ExplainerEnv/1.0 (hackathon project)",
            language="en",
        )
        search_results = await wiki.search(query, limit=MAX_SOURCE_RESULTS)
        titles = list(search_results.pages) if search_results and search_results.pages else []
        chunks: list[ResearchChunk] = []
        for title in titles:
            page = wiki.page(title)
            if not await page.exists():
                continue
            summary = await page.summary
            sections = await page.sections
            docs: list[tuple[str, str]] = []
            if summary:
                # The lead summary is labelled with the page title itself.
                docs.append((title, summary))
            docs.extend(_flatten_wikipedia_sections(sections))
            # Percent-encode the title so characters such as "?", "&", "%"
            # or non-ASCII letters still yield a valid article URL.
            url = "https://en.wikipedia.org/wiki/" + quote(title.replace(" ", "_"), safe="/:(),'")
            for section_title, text in docs:
                snippet = trim_text(text)
                if not snippet:
                    continue
                chunks.append(
                    ResearchChunk(
                        source="wikipedia",
                        tool="search_wikipedia",
                        title=f"{title} - {section_title}",
                        url=url,
                        text=snippet,
                        metadata={"page": title},
                    )
                )
        ranked = rank_chunks_for_query(query, intent, chunks, MAX_RETURNED_CHUNKS)
        return ResearchResult("search_wikipedia", query, ranked, raw_count=len(chunks))
    except Exception as exc:
        return ResearchResult("search_wikipedia", query, error=str(exc))
async def search_hf_papers(query: str, intent: str = "") -> ResearchResult:
    """Search Hugging Face Papers and return query-ranked paper chunks.

    For each of the first ``MAX_SOURCE_RESULTS`` hits, the paper's markdown
    page (``https://huggingface.co/papers/<id>.md``) is fetched and split with
    ``chunk_markdown``. When that page is missing, empty, or the request
    fails at the transport level, the hit's summary is used as a single
    "Abstract" chunk instead, so one flaky page does not sink the search.

    Returns:
        A ``ResearchResult`` with the ranked chunks and ``raw_count`` set to
        the number of chunks before ranking, or with ``error`` set if the
        search request (or response handling) fails.
    """
    headers = {"User-Agent": "ExplainerEnv/1.0"}
    try:
        async with httpx.AsyncClient(timeout=15.0, follow_redirects=True) as client:
            resp = await client.get(
                "https://huggingface.co/api/papers/search",
                params={"q": query, "limit": MAX_SOURCE_RESULTS},
                headers=headers,
            )
            resp.raise_for_status()
            papers = resp.json() or []
            if not isinstance(papers, list):
                raise ValueError("Unexpected Hugging Face papers search response")
            chunks: list[ResearchChunk] = []
            for item in papers[:MAX_SOURCE_RESULTS]:
                # NOTE(review): results may carry the paper fields nested under
                # a "paper" key; top-level fields win, nested ones fill gaps.
                nested = item.get("paper") if isinstance(item.get("paper"), dict) else {}
                paper_id = item.get("id") or nested.get("id") or ""
                title = item.get("title") or nested.get("title") or "Untitled"
                summary = item.get("summary") or nested.get("summary") or ""
                url = f"https://huggingface.co/papers/{paper_id}" if paper_id else ""
                text_chunks: list[tuple[str, str]] = [("Abstract", summary)]
                if paper_id:
                    try:
                        md_resp = await client.get(
                            f"https://huggingface.co/papers/{paper_id}.md",
                            headers=headers,
                        )
                    except httpx.HTTPError:
                        # The markdown page is optional enrichment; keep the
                        # abstract rather than failing the whole search.
                        md_resp = None
                    if md_resp is not None and md_resp.status_code == 200 and md_resp.text.strip():
                        text_chunks = chunk_markdown(md_resp.text, "Abstract")
                for section_title, text in text_chunks:
                    snippet = trim_text(text)
                    if not snippet:
                        continue
                    chunks.append(
                        ResearchChunk(
                            source="hf_papers",
                            tool="search_hf_papers",
                            title=f"{title} - {section_title}",
                            url=url,
                            text=snippet,
                            metadata={"paper_id": paper_id},
                        )
                    )
        ranked = rank_chunks_for_query(query, intent, chunks, MAX_RETURNED_CHUNKS)
        return ResearchResult("search_hf_papers", query, ranked, raw_count=len(chunks))
    except Exception as exc:
        return ResearchResult("search_hf_papers", query, error=str(exc))
async def search_arxiv(query: str, intent: str = "") -> ResearchResult:
    """Query the arXiv Atom API and return query-ranked abstract chunks.

    Up to ``MAX_SOURCE_RESULTS`` entries are requested, sorted by relevance.
    Each entry with a non-empty abstract becomes one chunk whose metadata
    carries the publication date, non-empty author names and category terms.

    Returns:
        A ``ResearchResult`` with the ranked chunks and ``raw_count`` set to
        the number of chunks before ranking, or with ``error`` set to the
        exception message if the request or XML parsing fails.
    """

    def _clean(raw: str) -> str:
        # Atom titles/abstracts arrive hard-wrapped; unescape leftover HTML
        # entities and collapse every whitespace run to a single space.
        return " ".join(html.unescape(raw).split())

    try:
        async with httpx.AsyncClient(timeout=15.0) as client:
            resp = await client.get(
                "https://export.arxiv.org/api/query",
                params={
                    "search_query": f"all:{query}",
                    "start": 0,
                    "max_results": MAX_SOURCE_RESULTS,
                    "sortBy": "relevance",
                    "sortOrder": "descending",
                },
                headers={"User-Agent": "ExplainerEnv/1.0"},
            )
            resp.raise_for_status()
        root = ET.fromstring(resp.text)
        ns = {"atom": "http://www.w3.org/2005/Atom", "arxiv": "http://arxiv.org/schemas/atom"}
        chunks: list[ResearchChunk] = []
        for entry in root.findall("atom:entry", ns):
            # Clean before trimming (not after) so trim_text works on the
            # final, unescaped text.
            snippet = trim_text(_clean(_xml_text(entry, "atom:summary", ns)))
            if not snippet:
                continue
            title = _clean(_xml_text(entry, "atom:title", ns)) or "Untitled"
            authors = [
                name
                for name in (_xml_text(author, "atom:name", ns) for author in entry.findall("atom:author", ns))
                if name
            ]
            categories = [
                cat.attrib.get("term", "")
                for cat in entry.findall("atom:category", ns)
                if cat.attrib.get("term")
            ]
            chunks.append(
                ResearchChunk(
                    source="arxiv",
                    tool="search_arxiv",
                    title=title,
                    url=_xml_text(entry, "atom:id", ns),
                    text=snippet,
                    metadata={
                        "published": _xml_text(entry, "atom:published", ns),
                        "authors": authors,
                        "categories": categories,
                    },
                )
            )
        ranked = rank_chunks_for_query(query, intent, chunks, MAX_RETURNED_CHUNKS)
        return ResearchResult("search_arxiv", query, ranked, raw_count=len(chunks))
    except Exception as exc:
        return ResearchResult("search_arxiv", query, error=str(exc))
async def search_scholar(query: str, intent: str = "") -> ResearchResult:
    """Search Semantic Scholar and return query-ranked abstract chunks.

    Up to ``MAX_SOURCE_RESULTS`` papers are requested; papers without an
    abstract are skipped. JSON ``null`` values are normalised to sensible
    defaults ("Untitled", empty URL, zero citations), and the author names
    requested via ``fields`` are exposed in the chunk metadata.

    Returns:
        A ``ResearchResult`` with the ranked chunks and ``raw_count`` set to
        the number of chunks before ranking, or with ``error`` set to the
        exception message if the request fails.
    """
    try:
        async with httpx.AsyncClient(timeout=15.0) as client:
            resp = await client.get(
                "https://api.semanticscholar.org/graph/v1/paper/search",
                params={
                    "query": query,
                    "limit": MAX_SOURCE_RESULTS,
                    "fields": "title,abstract,url,year,citationCount,venue,authors",
                },
                headers={"User-Agent": "ExplainerEnv/1.0"},
            )
            resp.raise_for_status()
            # ``.get(key, default)`` does not help when the key is present
            # with a null value, hence the ``or`` fallbacks throughout.
            papers = (resp.json() or {}).get("data") or []
        chunks: list[ResearchChunk] = []
        for paper in papers:
            snippet = trim_text(paper.get("abstract") or "")
            if not snippet:
                continue
            authors = [
                author["name"]
                for author in paper.get("authors") or []
                if isinstance(author, dict) and author.get("name")
            ]
            chunks.append(
                ResearchChunk(
                    source="semantic_scholar",
                    tool="search_scholar",
                    title=paper.get("title") or "Untitled",
                    url=paper.get("url") or "",
                    text=snippet,
                    metadata={
                        "year": paper.get("year"),
                        "venue": paper.get("venue"),
                        "citation_count": paper.get("citationCount") or 0,
                        "authors": authors,
                    },
                )
            )
        ranked = rank_chunks_for_query(query, intent, chunks, MAX_RETURNED_CHUNKS)
        return ResearchResult("search_scholar", query, ranked, raw_count=len(chunks))
    except Exception as exc:
        return ResearchResult("search_scholar", query, error=str(exc))
async def fetch_docs(query: str, intent: str = "") -> ResearchResult:
    """Fetch allow-listed documentation pages for *query* and rank their sections.

    Only URLs chosen by ``_select_doc_urls`` are fetched, at most
    ``MAX_SOURCE_RESULTS`` of them. Each page is converted to text with
    ``_extract_page_text`` and split with ``chunk_markdown``. A page that
    fails to download or convert is skipped so the others can still
    contribute.

    Returns:
        A ``ResearchResult`` with the ranked chunks and ``raw_count`` set to
        the number of chunks before ranking. ``error`` is set instead when no
        URL matched the query, when no chunks were produced and at least one
        page failed (all failures are listed), or when an unexpected error
        occurs.
    """
    urls = _select_doc_urls(query)
    if not urls:
        return ResearchResult("fetch_docs", query, error="No allowed documentation target matched query")
    headers = {"User-Agent": "ExplainerEnv/1.0"}
    chunks: list[ResearchChunk] = []
    failures: list[str] = []
    try:
        async with httpx.AsyncClient(timeout=15.0, follow_redirects=True) as client:
            for title, url in urls[:MAX_SOURCE_RESULTS]:
                try:
                    resp = await client.get(url, headers=headers)
                    resp.raise_for_status()
                    text = _extract_page_text(resp.text)
                except Exception as exc:  # one bad page must not sink the rest
                    failures.append(f"{url}: {exc}")
                    continue
                for section_title, body in chunk_markdown(text, title):
                    snippet = trim_text(body)
                    if not snippet:
                        continue
                    chunks.append(
                        ResearchChunk(
                            source="docs",
                            tool="fetch_docs",
                            title=f"{title} - {section_title}",
                            url=url,
                            text=snippet,
                            metadata={"doc": title},
                        )
                    )
        if not chunks and failures:
            # Surface the failures instead of an empty, error-free result.
            return ResearchResult("fetch_docs", query, error="; ".join(failures))
        ranked = rank_chunks_for_query(query, intent, chunks, MAX_RETURNED_CHUNKS)
        return ResearchResult("fetch_docs", query, ranked, raw_count=len(chunks))
    except Exception as exc:
        return ResearchResult("fetch_docs", query, error=str(exc))
async def search_hf_hub(query: str, intent: str = "") -> ResearchResult:
    """Search Hugging Face Hub models and datasets and return ranked chunks.

    ``huggingface_hub`` is imported lazily; if it is missing, that is
    reported through ``error``. The synchronous Hub client runs via
    ``asyncio.to_thread``. At most two models and two datasets are requested.
    Each hit becomes one chunk built from its id, pipeline tag (models only)
    and tags.

    Returns:
        A ``ResearchResult`` with the ranked chunks and ``raw_count`` set to
        the number of chunks before ranking, or with ``error`` set if the
        import or the Hub queries fail.
    """
    try:
        from huggingface_hub import HfApi
    except ImportError as exc:
        return ResearchResult("search_hf_hub", query, error=f"huggingface_hub unavailable: {exc}")

    def _join_words(*parts: object) -> str:
        """Join non-empty parts with spaces, expanding list-like parts item by item."""
        words: list[str] = []
        for part in parts:
            if not part:
                continue
            if isinstance(part, (list, tuple, set)):
                # Emit tags as separate words rather than a Python list repr.
                words.extend(str(item) for item in part if item)
            else:
                words.append(str(part))
        return " ".join(words)

    def _load() -> list[ResearchChunk]:
        api = HfApi()
        chunks: list[ResearchChunk] = []
        for model in api.list_models(search=query, limit=2):
            # Prefer ``id`` (as the dataset branch below does); ``modelId`` is
            # only consulted as a fallback for older huggingface_hub objects.
            model_id = getattr(model, "id", None) or getattr(model, "modelId", "")
            text = _join_words(model_id, getattr(model, "pipeline_tag", None), getattr(model, "tags", None))
            chunks.append(
                ResearchChunk(
                    source="hf_hub_model",
                    tool="search_hf_hub",
                    title=model_id,
                    url=f"https://huggingface.co/{model_id}",
                    text=trim_text(text),
                    metadata={"downloads": getattr(model, "downloads", None)},
                )
            )
        for dataset in api.list_datasets(search=query, limit=2):
            dataset_id = getattr(dataset, "id", "")
            text = _join_words(dataset_id, getattr(dataset, "tags", None))
            chunks.append(
                ResearchChunk(
                    source="hf_hub_dataset",
                    tool="search_hf_hub",
                    title=dataset_id,
                    url=f"https://huggingface.co/datasets/{dataset_id}",
                    text=trim_text(text),
                    metadata={"downloads": getattr(dataset, "downloads", None)},
                )
            )
        return chunks

    try:
        chunks = await asyncio.to_thread(_load)
        ranked = rank_chunks_for_query(query, intent, chunks, MAX_RETURNED_CHUNKS)
        return ResearchResult("search_hf_hub", query, ranked, raw_count=len(chunks))
    except Exception as exc:
        return ResearchResult("search_hf_hub", query, error=str(exc))
# Registry used by ``run_research_tool``: public tool name -> coroutine that
# takes ``(query, intent)`` and returns a ``ResearchResult``. Keyword names
# are the tool names, in the same order as before.
_TOOL_RUNNERS: dict[str, Callable[[str, str], Awaitable[ResearchResult]]] = dict(
    search_wikipedia=search_wikipedia,
    search_hf_papers=search_hf_papers,
    search_arxiv=search_arxiv,
    search_scholar=search_scholar,
    fetch_docs=fetch_docs,
    search_hf_hub=search_hf_hub,
)
# Wikipedia section headings (matched against the lower-cased title) whose
# content is reference/navigation boilerplate; such sections and their
# subsections are left out of the research chunks.
_SKIP_SECTIONS = frozenset(
    "references|external links|see also|further reading|notes|citations|bibliography|sources".split("|")
)
# Allow-list of documentation pages ``fetch_docs`` may download, keyed by the
# library name ``_select_doc_urls`` looks for in the query. Each value is a
# list of ``(display title, URL)`` pairs.
_DOC_URLS = dict(
    marimo=[
        ("marimo CLI", "https://docs.marimo.io/cli/"),
        ("marimo lint rules", "https://docs.marimo.io/guides/lint_rules/"),
        ("marimo duplicate definitions", "https://docs.marimo.io/guides/understanding_errors/multiple_definitions/"),
        ("marimo plotting", "https://docs.marimo.io/guides/working_with_data/plotting/"),
    ],
    manim=[
        ("Manim quickstart", "https://docs.manim.community/en/stable/tutorials/quickstart.html"),
        ("Manim reference", "https://docs.manim.community/en/stable/reference.html"),
    ],
    numpy=[("NumPy user guide", "https://numpy.org/doc/stable/user/")],
    scipy=[("SciPy user guide", "https://docs.scipy.org/doc/scipy/tutorial/")],
    sklearn=[("scikit-learn user guide", "https://scikit-learn.org/stable/user_guide.html")],
    pandas=[("pandas user guide", "https://pandas.pydata.org/docs/user_guide/")],
    pytorch=[("PyTorch docs", "https://pytorch.org/docs/stable/index.html")],
    plotly=[("Plotly Python docs", "https://plotly.com/python/")],
)
def _flatten_wikipedia_sections(sections, max_depth: int = 2, depth: int = 0) -> list[tuple[str, str]]:
    """Collect ``(title, stripped text)`` pairs from a section tree in document order.

    A section whose lower-cased title is in ``_SKIP_SECTIONS`` is dropped
    together with all of its subsections. A section with blank text adds no
    pair of its own but its subsections are still visited. *depth* is the
    nesting level of *sections*; children are only visited while that level
    is below *max_depth*.
    """
    collected: list[tuple[str, str]] = []
    # Explicit stack of (section, level). Children are pushed in reverse so
    # they pop in original order, giving the same pre-order walk as recursion.
    stack = [(node, depth) for node in reversed(list(sections))]
    while stack:
        node, level = stack.pop()
        if node.title.lower() in _SKIP_SECTIONS:
            continue
        body = node.text.strip()
        if body:
            collected.append((node.title, body))
        if level < max_depth and node.sections:
            stack.extend((child, level + 1) for child in reversed(list(node.sections)))
    return collected
def _xml_text(node, path: str, ns: dict[str, str]) -> str:
    """Return the stripped text of the first element under *node* matching *path*.

    *ns* maps namespace prefixes used in *path* to URIs, as accepted by
    ``Element.find``. Returns ``""`` when nothing matches or the matched
    element has no text.
    """
    element = node.find(path, ns)
    if element is None or not element.text:
        return ""
    return element.text.strip()
def _select_doc_urls(query: str) -> list[tuple[str, str]]:
    """Choose allow-listed ``(title, url)`` doc pages relevant to *query*.

    Every ``_DOC_URLS`` entry whose key occurs in the lower-cased query is
    selected, in registry order (the ``sklearn`` key also matches the spelling
    "scikit-learn"). If nothing matched but the query mentions a generic
    notebook/plotting term, the marimo, manim and plotly pages are returned
    instead. An empty list means no documentation target applies.
    """
    lowered = query.lower()
    picked = [
        entry
        for key, entries in _DOC_URLS.items()
        if key in lowered or key.replace("sklearn", "scikit-learn") in lowered
        for entry in entries
    ]
    if picked:
        return picked
    generic_terms = ("notebook", "animation", "plot", "chart", "lint")
    if any(term in lowered for term in generic_terms):
        for key in ("marimo", "manim", "plotly"):
            picked.extend(_DOC_URLS[key])
    return picked
def _extract_page_text(markup: str) -> str:
    """Convert an HTML page into text, preferring trafilatura's markdown output.

    ``trafilatura`` is imported lazily. When it is unavailable, raises, or
    extracts nothing, a regex-based fallback is used. The fallback:

    * removes HTML comments and script/style/noscript/template elements;
    * strips remaining tags;
    * unescapes entities;
    * collapses whitespace into single spaces, trimming both ends.
    """
    try:
        import trafilatura

        extracted = trafilatura.extract(markup, output_format="markdown")
        if extracted:
            return extracted
    except Exception:  # optional dependency: any failure selects the fallback
        pass
    # Remove comments and non-content elements before the generic tag regex:
    # their bodies may contain ">" which would otherwise leave debris behind.
    text = re.sub(r"<!--.*?-->", " ", markup, flags=re.DOTALL)
    text = re.sub(
        r"<(script|style|noscript|template)\b.*?</\1\s*>",
        " ",
        text,
        flags=re.DOTALL | re.IGNORECASE,
    )
    text = re.sub(r"<[^>]+>", " ", text)
    # Unescape before collapsing so entity-encoded whitespace (e.g. &nbsp;)
    # is normalized too; doing it after tag stripping keeps "&lt;b&gt;" as text.
    text = html.unescape(text)
    return re.sub(r"\s+", " ", text).strip()