|
|
"""Graph node implementations for DeepBoner research.""" |
|
|
|
|
|
from typing import Any, Literal |
|
|
|
|
|
import structlog |
|
|
from langchain_core.language_models.chat_models import BaseChatModel |
|
|
from langchain_core.messages import AIMessage |
|
|
from langchain_core.output_parsers import PydanticOutputParser |
|
|
from langchain_core.prompts import ChatPromptTemplate |
|
|
from pydantic import BaseModel, Field |
|
|
from pydantic_ai import Agent |
|
|
|
|
|
from src.agent_factory.judges import get_model |
|
|
from src.agents.graph.state import Hypothesis, ResearchState |
|
|
from src.prompts.hypothesis import SYSTEM_PROMPT as HYPOTHESIS_SYSTEM_PROMPT |
|
|
from src.prompts.hypothesis import format_hypothesis_prompt |
|
|
from src.prompts.report import SYSTEM_PROMPT as REPORT_SYSTEM_PROMPT |
|
|
from src.prompts.report import format_report_prompt |
|
|
from src.services.embedding_protocol import EmbeddingServiceProtocol |
|
|
from src.tools.base import SearchTool |
|
|
from src.tools.clinicaltrials import ClinicalTrialsTool |
|
|
from src.tools.europepmc import EuropePMCTool |
|
|
from src.tools.pubmed import PubMedTool |
|
|
from src.tools.search_handler import SearchHandler |
|
|
from src.utils.citation_validator import validate_references |
|
|
from src.utils.models import ( |
|
|
Citation, |
|
|
Evidence, |
|
|
HypothesisAssessment, |
|
|
MechanismHypothesis, |
|
|
ResearchReport, |
|
|
) |
|
|
|
|
|
# Module-level structlog logger shared by every node and helper in this file;
# callers pass structured context as keyword arguments (e.g. ``query=...``).
logger = structlog.get_logger()
|
|
|
|
|
|
|
|
def _convert_hypothesis_to_mechanism(h: Hypothesis) -> MechanismHypothesis:
    """Convert a state ``Hypothesis`` into a ``MechanismHypothesis`` for report generation.

    ``judge_node`` stores each mechanism as one arrow-delimited statement::

        "drug -> target -> pathway -> effect"

    This parses it back into the structured fields:

    * ``" -> "`` is the delimiter when present, otherwise bare ``"->"``. Preferring
      the spaced form means a name containing ``"->"`` without spaces is not split
      when the canonical format is used.
    * Exactly four non-empty parts map directly onto drug/target/pathway/effect.
    * More than four parts (first four non-empty) keep the first three and
      re-join the remainder with the same delimiter as the effect. This covers
      effects that themselves contain the delimiter.
    * Anything else is logged as a warning. Drug, target and pathway fall back
      to ``"Unknown"``, and the effect becomes the stripped statement, or
      ``"Unknown effect"`` when the statement is empty, whitespace-only or None.

    Confidence and the supporting/contradicting evidence id lists are carried
    over unchanged.

    Args:
        h: Hypothesis from the graph state.

    Returns:
        The equivalent ``MechanismHypothesis``.
    """
    # Normalize once so a None/whitespace-only statement cannot crash the
    # membership test below or produce an empty effect in the fallback.
    statement = h.statement.strip() if h.statement else ""

    separator = " -> " if " -> " in statement else "->"
    parts = [part.strip() for part in statement.split(separator)]

    if len(parts) == 4 and all(parts):
        drug, target, pathway, effect = parts
    elif len(parts) > 4 and all(parts[:4]):
        # The effect text itself contained the delimiter; keep it intact.
        drug, target, pathway = parts[:3]
        effect = separator.join(parts[3:])
        logger.debug(
            "Hypothesis has extra parts, joined into effect",
            hypothesis_id=h.id,
            parts_count=len(parts),
        )
    else:
        logger.warning(
            "Failed to parse hypothesis statement format",
            hypothesis_id=h.id,
            statement=statement[:100],
            parts_count=len(parts),
        )
        drug = target = pathway = "Unknown"
        effect = statement or "Unknown effect"

    return MechanismHypothesis(
        drug=drug,
        target=target,
        pathway=pathway,
        effect=effect,
        confidence=h.confidence,
        supporting_evidence=h.supporting_evidence_ids,
        contradicting_evidence=h.contradicting_evidence_ids,
    )
|
|
|
|
|
|
|
|
def _results_to_evidence(results: list[dict[str, Any]]) -> list[Evidence]:
    """Convert ``search_similar`` results into ``Evidence`` objects.

    Shared by ``judge_node`` and ``synthesize_node`` to avoid duplicating this
    mapping. Each result is read as a dict with these keys:

    * ``id``: used as the citation URL (default ``""``).
    * ``content``: the evidence text (default ``""``).
    * ``metadata``: an optional mapping with ``title``, ``source``, ``date``
      and ``authors`` keys.

    Args:
        results: Raw result dicts returned by the embedding service.

    Returns:
        One ``Evidence`` per result, in the same order.
    """
    evidence_list: list[Evidence] = []
    for r in results:
        # ``metadata`` may be present but None; treat that like a missing mapping.
        meta = r.get("metadata") or {}

        raw_authors = meta.get("authors") or ""
        if isinstance(raw_authors, str):
            # NOTE(review): authors appear to be stored as one comma-separated
            # string (presumably because the store only keeps scalar metadata).
            # Blank entries from "A, , B" or trailing commas are dropped.
            author_list = [a.strip() for a in raw_authors.split(",") if a.strip()]
        else:
            # Already a sequence of names; normalize and drop blanks.
            author_list = [str(a).strip() for a in raw_authors if str(a).strip()]

        evidence_list.append(
            Evidence(
                content=r.get("content", ""),
                citation=Citation(
                    url=r.get("id", ""),
                    title=meta.get("title", "Unknown"),
                    source=meta.get("source", "Unknown"),
                    date=meta.get("date", ""),
                    authors=author_list,
                ),
            )
        )
    return evidence_list
|
|
|
|
|
|
|
|
|
|
|
class SupervisorDecision(BaseModel):
    """The decision made by the supervisor."""

    # NOTE: Pydantic puts the class docstring and field descriptions into the
    # model's JSON schema. ``supervisor_node`` passes
    # ``PydanticOutputParser.get_format_instructions()`` into its prompt, so
    # this text presumably reaches the LLM. Edit it with that in mind.

    # Routing target chosen by the supervisor; values outside this set fail parsing.
    next_step: Literal["search", "judge", "resolve", "synthesize", "finish"] = Field(
        ...,
        description="The next step to take in the research process.",
    )

    # Free-text justification, echoed back as a "Supervisor: ..." message.
    reasoning: str = Field(
        ...,
        description="Reasoning for this decision.",
    )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def search_node(
    state: ResearchState, embedding_service: EmbeddingServiceProtocol | None = None
) -> dict[str, Any]:
    """Run the state's query against PubMed, ClinicalTrials.gov and Europe PMC.

    When an embedding service is provided and the search returned evidence,
    the evidence is passed through ``embedding_service.deduplicate``. That
    presumably returns only items not already known to the service (verify
    against the protocol). Without a service, every returned item counts as new.

    Args:
        state: Graph state; only ``state["query"]`` is read.
        embedding_service: Optional service used for de-duplication.

    Returns:
        A partial state update containing:

        * ``evidence_ids``: citation URLs of the evidence counted as new.
        * ``messages``: one status message with the total/new counts and any
          per-source errors.
    """
    query = state["query"]
    logger.info("search_node: executing search", query=query)

    # A fresh set of tools and a fresh handler are built for each invocation.
    tools: list[SearchTool] = [PubMedTool(), ClinicalTrialsTool(), EuropePMCTool()]
    handler = SearchHandler(tools=tools)

    result = await handler.execute(query)

    if embedding_service and result.evidence:
        new_evidence = await embedding_service.deduplicate(result.evidence)
    else:
        # Nothing to de-duplicate against: report everything that was returned.
        # Previously the ids were dropped here even though they were counted.
        new_evidence = result.evidence

    new_ids = [ev.citation.url for ev in new_evidence]

    message = (
        f"Search completed. Found {result.total_found} total, "
        f"{len(new_ids)} unique new papers."
    )
    if result.errors:
        message += f" Errors: {'; '.join(result.errors)}"

    return {
        "evidence_ids": new_ids,
        "messages": [AIMessage(content=message)],
    }
|
|
|
|
|
|
|
|
async def judge_node(
    state: ResearchState, embedding_service: EmbeddingServiceProtocol | None = None
) -> dict[str, Any]:
    """Ask the LLM for mechanism hypotheses grounded in the stored evidence.

    If an embedding service is available, the 20 results most similar to the
    query are converted to ``Evidence`` and included in the prompt. A
    pydantic-ai ``Agent`` with ``HypothesisAssessment`` output then produces
    the hypotheses. Each one is stored in state as a ``"proposed"``
    ``Hypothesis`` whose statement is ``"drug -> target -> pathway -> effect"``.

    Only the agent run and result conversion are guarded. Failures while
    fetching evidence, building the model or formatting the prompt propagate
    to the caller.

    Args:
        state: Graph state; ``state["query"]`` is read.
        embedding_service: Optional source of similar evidence.

    Returns:
        On success: ``hypotheses``, a summary message and ``next_step="resolve"``.
        On agent failure: an error message and ``next_step="search"``.
    """
    logger.info("judge_node: evaluating evidence")

    evidence_context: list[Evidence] = []
    if embedding_service:
        scored_points = await embedding_service.search_similar(state["query"], n_results=20)
        evidence_context = _results_to_evidence(scored_points)

    agent = Agent(
        model=get_model(),
        output_type=HypothesisAssessment,
        system_prompt=HYPOTHESIS_SYSTEM_PROMPT,
    )

    prompt = await format_hypothesis_prompt(
        query=state["query"], evidence=evidence_context, embeddings=embedding_service
    )

    try:
        result = await agent.run(prompt)
        assessment = result.output

        new_hypotheses = [
            Hypothesis(
                # NOTE(review): the id is the drug name, so two hypotheses about
                # the same drug share an id. Confirm how the state reducer
                # treats duplicates.
                id=h.drug,
                statement=f"{h.drug} -> {h.target} -> {h.pathway} -> {h.effect}",
                status="proposed",
                confidence=h.confidence,
                # The assessment's hypotheses carry the same fields as
                # MechanismHypothesis, including evidence lists. Keep those lists
                # instead of discarding them, so they survive the round trip
                # back through _convert_hypothesis_to_mechanism.
                supporting_evidence_ids=list(h.supporting_evidence),
                contradicting_evidence_ids=list(h.contradicting_evidence),
            )
            for h in assessment.hypotheses
        ]

        return {
            "hypotheses": new_hypotheses,
            "messages": [AIMessage(content=f"Judge: Generated {len(new_hypotheses)} hypotheses.")],
            "next_step": "resolve",
        }
    except Exception as e:
        # Best-effort node: record the failure (with traceback) and send the
        # graph back to search rather than aborting the run.
        logger.exception("judge_node failed", error=str(e))
        return {"messages": [AIMessage(content=f"Judge Error: {e!s}")], "next_step": "search"}
|
|
|
|
|
|
|
|
async def resolve_node(
    state: ResearchState, embedding_service: EmbeddingServiceProtocol | None = None
) -> dict[str, Any]:
    """Report how many hypotheses have confidence strictly above 0.8.

    Despite being the graph's "resolve" step, this node neither reads nor
    modifies ``state["conflicts"]``. It only counts high-confidence hypotheses
    and appends one status message. ``embedding_service`` is accepted for
    signature parity with the other nodes and is not used.

    NOTE(review): conflicts are never closed here. A supervisor that routes to
    "resolve" while conflicts are open may therefore keep returning to this
    node until the iteration cap. Confirm this is intended.

    Returns:
        ``{"messages": [AIMessage(...)]}`` with a single status message.
    """
    confident_count = sum(1 for hyp in state["hypotheses"] if hyp.confidence > 0.8)

    if not confident_count:
        status = AIMessage(content="Resolver: No high confidence hypotheses yet.")
    else:
        status = AIMessage(
            content=(
                f"Resolver: Found {confident_count} high confidence hypotheses. "
                "Conflicts resolved."
            )
        )

    return {"messages": [status]}
|
|
|
|
|
|
|
|
async def synthesize_node(
    state: ResearchState, embedding_service: EmbeddingServiceProtocol | None = None
) -> dict[str, Any]:
    """Generate the final research report as a markdown message.

    If an embedding service is available, the 50 results most similar to the
    query are used as evidence. State hypotheses are converted to
    ``MechanismHypothesis`` and a pydantic-ai ``Agent`` with ``ResearchReport``
    output writes the report. References are checked with
    ``validate_references`` against the same evidence before rendering.

    Only the agent run, validation and rendering are guarded. Failures while
    fetching evidence, building the model or formatting the prompt propagate
    to the caller.

    Args:
        state: Graph state; ``state["query"]`` and ``state["hypotheses"]`` are read.
        embedding_service: Optional source of similar evidence.

    Returns:
        ``{"messages": [...], "next_step": "finish"}``. The message is the
        report markdown on success, or a "Synthesis Error" message on failure.
    """
    logger.info("synthesize_node: generating report")

    evidence_context: list[Evidence] = []
    if embedding_service:
        scored_points = await embedding_service.search_similar(state["query"], n_results=50)
        evidence_context = _results_to_evidence(scored_points)

    agent = Agent(
        model=get_model(),
        output_type=ResearchReport,
        system_prompt=REPORT_SYSTEM_PROMPT,
    )

    mechanism_hypotheses = [_convert_hypothesis_to_mechanism(h) for h in state["hypotheses"]]

    # dict.fromkeys de-duplicates while keeping first-seen order. The previous
    # list(set(...)) ordering of strings changes between processes (hash
    # randomization), which made the prompt non-reproducible.
    sources = list(dict.fromkeys(e.citation.source for e in evidence_context))

    prompt = await format_report_prompt(
        query=state["query"],
        evidence=evidence_context,
        hypotheses=mechanism_hypotheses,
        assessment={},
        metadata={"sources": sources},
        embeddings=embedding_service,
    )

    try:
        result = await agent.run(prompt)
        report = validate_references(result.output, evidence_context)
        return {"messages": [AIMessage(content=report.to_markdown())], "next_step": "finish"}
    except Exception as e:
        # Best-effort final step: surface the error as the output message
        # (with the traceback logged) rather than failing the whole run.
        logger.exception("synthesize_node failed", error=str(e))
        return {"messages": [AIMessage(content=f"Synthesis Error: {e!s}")], "next_step": "finish"}
|
|
|
|
|
|
|
|
async def supervisor_node(state: ResearchState, llm: BaseChatModel | None = None) -> dict[str, Any]:
    """Choose the next graph step, parsing the LLM reply into ``SupervisorDecision``.

    Routing, in order:

    1. Once ``iteration_count`` reaches ``max_iterations``, return
       ``"synthesize"`` without incrementing the counter.
    2. With no ``llm``, always return ``"search"`` and increment the counter.
    3. Otherwise, send a state summary to the LLM and parse its reply with a
       ``PydanticOutputParser``. The summary covers the query, the hypothesis
       count, the open-conflict count and the iteration. The parsed
       ``next_step`` is returned, the counter is incremented, and the
       reasoning is added as a message.
    4. If building the inputs, the LLM call or parsing fails, log the error
       and fall back to ``"synthesize"`` (counter incremented).

    Args:
        state: Graph state; ``query``, ``hypotheses``, ``conflicts``,
            ``iteration_count`` and ``max_iterations`` are read.
        llm: Chat model used for routing; ``None`` selects the fixed fallback.

    Returns:
        A partial state update with ``next_step``, ``iteration_count`` and,
        in cases 3 and 4, a ``messages`` entry.
    """
    if state["iteration_count"] >= state["max_iterations"]:
        return {"next_step": "synthesize", "iteration_count": state["iteration_count"]}

    if llm is None:
        return {"next_step": "search", "iteration_count": state["iteration_count"] + 1}

    parser = PydanticOutputParser(pydantic_object=SupervisorDecision)

    # ``format_instructions`` is supplied as a template variable, so the JSON
    # braces it contains are not parsed as template placeholders.
    prompt = ChatPromptTemplate.from_messages(
        [
            (
                "system",
                "You are the Research Supervisor. Manage the workflow.\n\n"
                "State Summary:\n"
                "- Query: {query}\n"
                "- Hypotheses: {hypo_count}\n"
                "- Conflicts: {conflict_count}\n"
                "- Iteration: {iteration}/{max_iter}\n\n"
                "Decide the next step based on this logic:\n"
                "1. If there are open conflicts -> 'resolve'\n"
                "2. If hypotheses are unverified or few -> 'search'\n"
                "3. If new evidence needs evaluation -> 'judge'\n"
                "4. If hypotheses are confirmed -> 'synthesize'\n\n"
                "{format_instructions}",
            ),
            ("user", "What is the next step?"),
        ]
    )

    chain = prompt | llm | parser

    try:
        # Input construction stays inside the try so that malformed state also
        # takes the synthesis fallback rather than crashing the graph.
        decision: SupervisorDecision = await chain.ainvoke(
            {
                "query": state["query"],
                "hypo_count": len(state["hypotheses"]),
                "conflict_count": len([c for c in state["conflicts"] if c.status == "open"]),
                "iteration": state["iteration_count"],
                "max_iter": state["max_iterations"],
                "format_instructions": parser.get_format_instructions(),
            }
        )
        return {
            "next_step": decision.next_step,
            "iteration_count": state["iteration_count"] + 1,
            "messages": [AIMessage(content=f"Supervisor: {decision.reasoning}")],
        }
    except Exception as e:
        # Previously swallowed without any log entry; record it with the
        # traceback, consistent with the other nodes' error handling.
        logger.exception("supervisor_node failed, proceeding to synthesis", error=str(e))
        return {
            "next_step": "synthesize",
            "iteration_count": state["iteration_count"] + 1,
            "messages": [AIMessage(content=f"Supervisor Error: {e!s}. Proceeding to synthesis.")],
        }
|
|
|