# recon / src/agents/retriever.py
# Last commit ad651e3 by MukulRay: "Phase 2.2-2.3: add doi to Paper, populate
# from S2 externalIds, integrate OpenAlex into retriever"
import logging
import time
from dotenv import load_dotenv
from src.state import ResearchState, Paper, WebResult
from src.retriever_utils import (
search_semantic_scholar,
search_web,
build_citation_graph,
hybrid_score,
)
# Populate os.environ from a local .env file (python-dotenv), if one exists.
load_dotenv()

# Module-scoped logger; handlers/levels come from the application's config.
logger = logging.getLogger(__name__)

# Per-source cap on results fetched for each sub-question.
S2_LIMIT_PER_QUESTION = 4  # Semantic Scholar papers
WEB_LIMIT_PER_QUESTION = 3  # web search results
def _to_search_query(question: str) -> str:
"""
Convert a natural language question to a short keyword query for S2.
Strips question words and keeps the core noun phrases.
"""
import re
# Remove question words and common filler
stopwords = [
"what are", "what is", "how does", "how do", "why is", "why are",
"when did", "where is", "which are", "tell me about",
"foundational papers on", "recent advances in", "open challenges in",
"the current state of", "published in", "for llms", "in llms",
"papers on", "research on", "advances in", "challenges in",
"were", "was", "the", "a ", "an ", "in ", "of ", "for ", "on ",
]
q = question.lower().strip().rstrip("?")
for sw in stopwords:
q = q.replace(sw, " ")
# Collapse whitespace
q = re.sub(r"\s+", " ", q).strip()
# Cap at 6 words
words = q.split()[:6]
return " ".join(words)
# ---------------------------------------------------------------------------
# Retriever node — called by LangGraph
# ---------------------------------------------------------------------------
# OpenAlex is consulted only when the primary sources returned fewer papers
# than this.
OPENALEX_AUGMENT_THRESHOLD = 12
# How many leading sub-questions are sent to OpenAlex, and how many results
# are requested for each.
OPENALEX_MAX_QUESTIONS = 2
OPENALEX_LIMIT_PER_QUESTION = 3
# Conservative stand-in semantic similarity for OpenAlex papers, which are not
# scored against a query embedding in this module.
OPENALEX_DEFAULT_SEMANTIC_SIM = 0.3
# Leading forms a DOI may carry. The OpenAlex API reports DOIs as
# "https://doi.org/..." URLs, while other sources may give the bare "10.x/y".
_DOI_PREFIXES = (
    "https://doi.org/",
    "http://doi.org/",
    "https://dx.doi.org/",
    "http://dx.doi.org/",
    "doi:",
)


def retriever_node(state: ResearchState) -> ResearchState:
    """
    LangGraph node: gather papers and web evidence for each sub-question.

    Reads:  sub_questions, decay_config (defaults to "linear")
    Writes: retrieved_papers, web_results, citation_graph

    Every sub-question is sent to Semantic Scholar (as a keyword query) and to
    web search; papers are de-duplicated by paper_id and web results by URL.
    If fewer than OPENALEX_AUGMENT_THRESHOLD papers were found, OpenAlex
    results for the first sub-questions are added (best-effort). Papers are
    returned sorted by hybrid score, highest first, together with the citation
    graph built from them. All other state keys pass through unchanged.
    """
    sub_questions = state.get("sub_questions") or []
    decay_config = state.get("decay_config", "linear")
    if not sub_questions:
        logger.warning("Retriever received no sub-questions — returning empty")
        return {
            **state,
            "retrieved_papers": [],
            "web_results": [],
            "citation_graph": {},
        }

    all_papers, all_web = _search_primary_sources(sub_questions, decay_config)

    # --- Phase 2.2: OpenAlex augmentation ---
    if len(all_papers) < OPENALEX_AUGMENT_THRESHOLD:
        all_papers.extend(
            _fetch_openalex_papers(sub_questions, all_papers, decay_config)
        )

    # Sort papers by hybrid score descending
    all_papers.sort(key=lambda p: p.hybrid_score, reverse=True)
    # Build citation graph from retrieved papers
    citation_graph = build_citation_graph(all_papers)
    logger.info(
        "Retriever complete: %d papers, %d web results, %d citation edges",
        len(all_papers),
        len(all_web),
        sum(len(v) for v in citation_graph.values()),
    )
    return {
        **state,
        "retrieved_papers": all_papers,
        "web_results": all_web,
        "citation_graph": citation_graph,
    }


def _search_primary_sources(sub_questions: list[str], decay_config):
    """
    Query Semantic Scholar and web search for every sub-question.

    Papers with a paper_id are de-duplicated by it; papers without one are
    always kept. Web results are de-duplicated by URL, and results without a
    URL are dropped. Sleeps 1 s between sub-questions to be gentle on the APIs.

    Returns:
        (papers, web_results) as two new lists, in retrieval order.
    """
    papers: list[Paper] = []
    web: list[WebResult] = []
    seen_paper_ids: set[str] = set()
    seen_urls: set[str] = set()
    last_index = len(sub_questions) - 1

    for i, question in enumerate(sub_questions):
        logger.info(
            "Retriever: searching for sub-question %d: %s", i + 1, question[:60]
        )

        # --- Semantic Scholar ---
        s2_query = _to_search_query(question)
        logger.info(" S2 keyword query: '%s'", s2_query)
        for paper in search_semantic_scholar(
            s2_query,
            limit=S2_LIMIT_PER_QUESTION,
            decay_config=decay_config,
        ):
            if paper.paper_id:
                if paper.paper_id in seen_paper_ids:
                    continue
                seen_paper_ids.add(paper.paper_id)
            papers.append(paper)

        # --- Web search ---
        for result in search_web(question, limit=WEB_LIMIT_PER_QUESTION):
            if result.url and result.url not in seen_urls:
                seen_urls.add(result.url)
                web.append(result)

        # Small pause between sub-questions to be gentle on APIs
        if i < last_index:
            time.sleep(1)

    return papers, web


def _normalize_doi(doi: str) -> str:
    """Return *doi* stripped, lowercased and without any resolver/"doi:" prefix ("" if empty)."""
    normalized = (doi or "").strip().lower()
    for prefix in _DOI_PREFIXES:
        if normalized.startswith(prefix):
            return normalized[len(prefix):]
    return normalized


def _openalex_result_to_paper(result: dict, decay_config) -> "Paper | None":
    """
    Build a scored Paper from one OpenAlex result dict.

    ``authors`` is read as a comma-separated string. Returns None when the
    result has no title or no (non-zero) year. The hybrid score uses
    OPENALEX_DEFAULT_SEMANTIC_SIM because no query embedding is computed for
    OpenAlex papers.
    """
    paper = Paper(
        title=result.get("title") or "",
        abstract=result.get("abstract") or "",
        year=result.get("year") or 0,
        citation_count=result.get("citation_count") or 0,
        paper_id=result.get("paper_id") or "",
        authors=[
            a.strip()
            for a in (result.get("authors") or "").split(",")
            if a.strip()
        ],
        references=[],
        doi=result.get("doi") or "",
        source="openalex",
    )
    if not paper.title or not paper.year:
        return None
    paper.hybrid_score = hybrid_score(
        semantic_sim=OPENALEX_DEFAULT_SEMANTIC_SIM,
        year=paper.year,
        citation_count=paper.citation_count,
        decay_config=decay_config,
    )
    return paper


def _fetch_openalex_papers(
    sub_questions: list[str], existing: list[Paper], decay_config
) -> list[Paper]:
    """
    Return new OpenAlex papers for the first OPENALEX_MAX_QUESTIONS sub-questions.

    A result is skipped if its normalized DOI or its non-empty paper_id matches
    a paper in *existing* or one already accepted here. This is best-effort:
    a failure for one sub-question is logged and the next one is still tried.
    *existing* is not modified.
    """
    # Imported here (not at module top) so the OpenAlex module is only loaded
    # when augmentation actually runs.
    from src.openalex_utils import search_openalex

    seen_dois = {_normalize_doi(p.doi) for p in existing if p.doi}
    seen_ids = {p.paper_id for p in existing if p.paper_id}
    added: list[Paper] = []

    for question in sub_questions[:OPENALEX_MAX_QUESTIONS]:
        try:
            results = search_openalex(
                question, max_results=OPENALEX_LIMIT_PER_QUESTION
            )
            for r in results:
                doi = _normalize_doi(r.get("doi") or "")
                pid = r.get("paper_id") or ""
                # Empty DOIs/ids must never count as a match; otherwise every
                # id-less result after the first would be silently dropped.
                if (doi and doi in seen_dois) or (pid and pid in seen_ids):
                    continue
                paper = _openalex_result_to_paper(r, decay_config)
                if paper is None:
                    continue
                added.append(paper)
                if doi:
                    seen_dois.add(doi)
                if pid:
                    seen_ids.add(pid)
            time.sleep(0.5)  # be polite to OpenAlex
        except Exception as e:  # best-effort: never abort retrieval over OpenAlex
            logger.warning(
                "OpenAlex augmentation failed for '%s': %s", question[:40], e
            )

    return added