VibecoderMcSwaggins's picture
feat: Wire LlamaIndex RAG into Simple Mode (Tiered Embedding) (#83)
7baf8ba unverified
raw
history blame
12 kB
"""Graph node implementations for DeepBoner research."""
from typing import Any, Literal
import structlog
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import AIMessage
from langchain_core.output_parsers import PydanticOutputParser
from langchain_core.prompts import ChatPromptTemplate
from pydantic import BaseModel, Field
from pydantic_ai import Agent
from src.agent_factory.judges import get_model
from src.agents.graph.state import Hypothesis, ResearchState
from src.prompts.hypothesis import SYSTEM_PROMPT as HYPOTHESIS_SYSTEM_PROMPT
from src.prompts.hypothesis import format_hypothesis_prompt
from src.prompts.report import SYSTEM_PROMPT as REPORT_SYSTEM_PROMPT
from src.prompts.report import format_report_prompt
from src.services.embedding_protocol import EmbeddingServiceProtocol
from src.tools.base import SearchTool
from src.tools.clinicaltrials import ClinicalTrialsTool
from src.tools.europepmc import EuropePMCTool
from src.tools.pubmed import PubMedTool
from src.tools.search_handler import SearchHandler
from src.utils.citation_validator import validate_references
from src.utils.models import (
Citation,
Evidence,
HypothesisAssessment,
MechanismHypothesis,
ResearchReport,
)
logger = structlog.get_logger()
def _convert_hypothesis_to_mechanism(h: Hypothesis) -> MechanismHypothesis:
    """Rebuild a structured MechanismHypothesis from a state Hypothesis.

    The state Hypothesis encodes the mechanism as a chained statement,
    e.g. "drug -> target -> pathway -> effect"; this parses that text back
    into the four structured fields for report generation.
    """
    # Prefer the spaced separator; fall back to the compact form.
    sep = " -> " if " -> " in h.statement else "->"
    segments = [segment.strip() for segment in h.statement.split(sep)]
    n = len(segments)

    if n == 4 and all(segments):
        # Clean four-part statement: map positionally.
        drug, target, pathway, effect = segments
    elif n > 4 and all(segments[:4]):
        # Surplus trailing segments are folded back into the effect field.
        drug = segments[0]
        target = segments[1]
        pathway = segments[2]
        effect = sep.join(segments[3:])
        logger.debug(
            "Hypothesis has extra parts, joined into effect",
            hypothesis_id=h.id,
            parts_count=n,
        )
    else:
        # Unparseable statement: keep the raw text as the effect so no
        # information is lost, and record the failure for debugging.
        logger.warning(
            "Failed to parse hypothesis statement format",
            hypothesis_id=h.id,
            statement=h.statement[:100],  # Truncate for log safety
            parts_count=n,
        )
        drug = target = pathway = "Unknown"
        effect = h.statement.strip() if h.statement else "Unknown effect"

    return MechanismHypothesis(
        drug=drug,
        target=target,
        pathway=pathway,
        effect=effect,
        confidence=h.confidence,
        supporting_evidence=h.supporting_evidence_ids,
        contradicting_evidence=h.contradicting_evidence_ids,
    )
def _results_to_evidence(results: list[dict[str, Any]]) -> list[Evidence]:
    """Convert search_similar results to Evidence objects.

    Extracted helper to avoid code duplication between judge_node and
    synthesize_node.

    Args:
        results: Raw hits from ``embedding_service.search_similar``; each is
            a dict with optional ``content``, ``id`` and ``metadata`` keys.

    Returns:
        One Evidence per result, with citation fields pulled from metadata
        (missing fields fall back to "Unknown" / empty string).
    """
    evidence_list: list[Evidence] = []
    for r in results:
        meta = r.get("metadata", {})
        authors_str = meta.get("authors", "")
        # Split the comma-separated author string, dropping empty entries
        # produced by trailing/consecutive commas (e.g. "Smith J, ").
        # (Splitting "" also yields only an empty entry, so the previous
        # explicit empty-string branch is subsumed by the filter.)
        author_list = [a.strip() for a in authors_str.split(",") if a.strip()]
        evidence_list.append(
            Evidence(
                content=r.get("content", ""),
                citation=Citation(
                    url=r.get("id", ""),
                    title=meta.get("title", "Unknown"),
                    source=meta.get("source", "Unknown"),
                    date=meta.get("date", ""),
                    authors=author_list,
                ),
            )
        )
    return evidence_list
# --- Supervisor Output Schema ---
class SupervisorDecision(BaseModel):
    """The decision made by the supervisor.

    Parsed from the LLM response by the PydanticOutputParser in
    supervisor_node; ``next_step`` drives graph routing.
    """
    # Values must match the node names wired into the research graph.
    next_step: Literal["search", "judge", "resolve", "synthesize", "finish"] = Field(
        description="The next step to take in the research process."
    )
    reasoning: str = Field(description="Reasoning for this decision.")
# --- Nodes ---
async def search_node(
    state: ResearchState, embedding_service: EmbeddingServiceProtocol | None = None
) -> dict[str, Any]:
    """Execute search across all sources."""
    query = state["query"]
    logger.info("search_node: executing search", query=query)

    # Fan the query out across every configured literature source.
    search_tools: list[SearchTool] = [PubMedTool(), ClinicalTrialsTool(), EuropePMCTool()]
    result = await SearchHandler(tools=search_tools).execute(query)

    new_ids: list[str] = []
    if embedding_service and result.evidence:
        # Deduplicate and store (deduplicate() already calls add_evidence() internally)
        unique_evidence = await embedding_service.deduplicate(result.evidence)
        # Track IDs for state (evidence already stored by deduplicate())
        new_ids = [ev.citation.url for ev in unique_evidence]
        new_evidence_count = len(unique_evidence)
    else:
        # No vector store wired in (or nothing found): report the raw count.
        new_evidence_count = len(result.evidence)

    message = (
        f"Search completed. Found {result.total_found} total, "
        f"{new_evidence_count} unique new papers."
    )
    if result.errors:
        message += f" Errors: {'; '.join(result.errors)}"

    return {
        "evidence_ids": new_ids,
        "messages": [AIMessage(content=message)],
    }
async def judge_node(
    state: ResearchState, embedding_service: EmbeddingServiceProtocol | None = None
) -> dict[str, Any]:
    """Evaluate evidence and update hypothesis confidence."""
    logger.info("judge_node: evaluating evidence")

    # Pull the most relevant stored evidence, if a vector store is wired in.
    evidence_context: list[Evidence] = []
    if embedding_service:
        hits = await embedding_service.search_similar(state["query"], n_results=20)
        evidence_context = _results_to_evidence(hits)

    judge_agent = Agent(
        model=get_model(),
        output_type=HypothesisAssessment,
        system_prompt=HYPOTHESIS_SYSTEM_PROMPT,
    )
    prompt = await format_hypothesis_prompt(
        query=state["query"], evidence=evidence_context, embeddings=embedding_service
    )

    try:
        run_result = await judge_agent.run(prompt)
        assessment = run_result.output
        # Re-encode each structured hypothesis as a state Hypothesis whose
        # statement chains the mechanism fields together.
        new_hypotheses = [
            Hypothesis(
                id=h.drug,
                statement=f"{h.drug} -> {h.target} -> {h.pathway} -> {h.effect}",
                status="proposed",
                confidence=h.confidence,
                supporting_evidence_ids=[],
                contradicting_evidence_ids=[],
            )
            for h in assessment.hypotheses
        ]
        return {
            "hypotheses": new_hypotheses,
            "messages": [AIMessage(content=f"Judge: Generated {len(new_hypotheses)} hypotheses.")],
            "next_step": "resolve",
        }
    except Exception as e:
        # On any failure, fall back to another search round rather than abort.
        logger.error("judge_node failed", error=str(e))
        return {"messages": [AIMessage(content=f"Judge Error: {e!s}")], "next_step": "search"}
async def resolve_node(
    state: ResearchState, embedding_service: EmbeddingServiceProtocol | None = None
) -> dict[str, Any]:
    """Handle open conflicts."""
    # Hypotheses are Pydantic models, hence attribute access rather than keys.
    strong = [hyp for hyp in state["hypotheses"] if hyp.confidence > 0.8]
    if strong:
        note = (
            f"Resolver: Found {len(strong)} high confidence hypotheses. "
            "Conflicts resolved."
        )
    else:
        note = "Resolver: No high confidence hypotheses yet."
    return {"messages": [AIMessage(content=note)]}
async def synthesize_node(
    state: ResearchState, embedding_service: EmbeddingServiceProtocol | None = None
) -> dict[str, Any]:
    """Generate final report.

    Retrieves the most relevant stored evidence, converts state hypotheses
    into structured MechanismHypothesis objects, and asks the report agent
    to produce a ResearchReport rendered as markdown.
    """
    logger.info("synthesize_node: generating report")

    evidence_context: list[Evidence] = []
    if embedding_service:
        scored_points = await embedding_service.search_similar(state["query"], n_results=50)
        evidence_context = _results_to_evidence(scored_points)

    agent = Agent(
        model=get_model(),
        output_type=ResearchReport,
        system_prompt=REPORT_SYSTEM_PROMPT,
    )
    # Convert state hypotheses to MechanismHypothesis for report generation
    mechanism_hypotheses = [_convert_hypothesis_to_mechanism(h) for h in state["hypotheses"]]
    prompt = await format_report_prompt(
        query=state["query"],
        evidence=evidence_context,
        hypotheses=mechanism_hypotheses,
        assessment={},
        # sorted() makes the source list deterministic; bare list(set(...))
        # varied run-to-run under hash randomization.
        metadata={"sources": sorted({e.citation.source for e in evidence_context})},
        embeddings=embedding_service,
    )
    try:
        result = await agent.run(prompt)
        report = result.output
        # Cross-check report references against the retrieved evidence.
        report = validate_references(report, evidence_context)
        return {"messages": [AIMessage(content=report.to_markdown())], "next_step": "finish"}
    except Exception as e:
        logger.error("synthesize_node failed", error=str(e))
        return {"messages": [AIMessage(content=f"Synthesis Error: {e!s}")], "next_step": "finish"}
async def supervisor_node(state: ResearchState, llm: BaseChatModel | None = None) -> dict[str, Any]:
    """Route to next node based on state using robust Pydantic parsing."""
    iteration = state["iteration_count"]

    # Hard stop: iteration budget exhausted — force synthesis, skip the LLM.
    if iteration >= state["max_iterations"]:
        return {"next_step": "synthesize", "iteration_count": iteration}

    # No LLM wired in: default to another search round.
    if llm is None:
        return {"next_step": "search", "iteration_count": iteration + 1}

    parser = PydanticOutputParser(pydantic_object=SupervisorDecision)
    prompt = ChatPromptTemplate.from_messages(
        [
            (
                "system",
                "You are the Research Supervisor. Manage the workflow.\n\n"
                "State Summary:\n"
                "- Query: {query}\n"
                "- Hypotheses: {hypo_count}\n"
                "- Conflicts: {conflict_count}\n"
                "- Iteration: {iteration}/{max_iter}\n\n"
                "Decide the next step based on this logic:\n"
                "1. If there are open conflicts -> 'resolve'\n"
                "2. If hypotheses are unverified or few -> 'search'\n"
                "3. If new evidence needs evaluation -> 'judge'\n"
                "4. If hypotheses are confirmed -> 'synthesize'\n\n"
                "{format_instructions}",
            ),
            ("user", "What is the next step?"),
        ]
    )
    chain = prompt | llm | parser

    try:
        # Conflicts are Pydantic models, hence attribute access for status.
        open_conflicts = sum(1 for c in state["conflicts"] if c.status == "open")
        decision: SupervisorDecision = await chain.ainvoke(
            {
                "query": state["query"],
                "hypo_count": len(state["hypotheses"]),
                "conflict_count": open_conflicts,
                "iteration": iteration,
                "max_iter": state["max_iterations"],
                "format_instructions": parser.get_format_instructions(),
            }
        )
        return {
            "next_step": decision.next_step,
            "iteration_count": iteration + 1,
            "messages": [AIMessage(content=f"Supervisor: {decision.reasoning}")],
        }
    except Exception as e:
        # Any LLM/parsing failure degrades gracefully into synthesis.
        return {
            "next_step": "synthesize",
            "iteration_count": iteration + 1,
            "messages": [AIMessage(content=f"Supervisor Error: {e!s}. Proceeding to synthesis.")],
        }