| """Graph node implementations for DeepBoner research.""" |
|
|
| from typing import Any, Literal |
|
|
| import structlog |
| from langchain_core.language_models.chat_models import BaseChatModel |
| from langchain_core.messages import AIMessage |
| from langchain_core.output_parsers import PydanticOutputParser |
| from langchain_core.prompts import ChatPromptTemplate |
| from pydantic import BaseModel, Field |
| from pydantic_ai import Agent |
|
|
| from src.agent_factory.judges import get_model |
| from src.agents.graph.state import Hypothesis, ResearchState |
| from src.prompts.hypothesis import SYSTEM_PROMPT as HYPOTHESIS_SYSTEM_PROMPT |
| from src.prompts.hypothesis import format_hypothesis_prompt |
| from src.prompts.report import SYSTEM_PROMPT as REPORT_SYSTEM_PROMPT |
| from src.prompts.report import format_report_prompt |
| from src.services.embedding_protocol import EmbeddingServiceProtocol |
| from src.tools.base import SearchTool |
| from src.tools.clinicaltrials import ClinicalTrialsTool |
| from src.tools.europepmc import EuropePMCTool |
| from src.tools.pubmed import PubMedTool |
| from src.tools.search_handler import SearchHandler |
| from src.utils.citation_validator import validate_references |
| from src.utils.models import ( |
| Citation, |
| Evidence, |
| HypothesisAssessment, |
| MechanismHypothesis, |
| ResearchReport, |
| ) |
|
|
| logger = structlog.get_logger() |
|
|
|
|
def _convert_hypothesis_to_mechanism(h: Hypothesis) -> MechanismHypothesis:
    """Convert a state ``Hypothesis`` into a ``MechanismHypothesis`` for reporting.

    The state hypothesis encodes the mechanism as a chained statement,
    e.g. ``"drug -> target -> pathway -> effect"``; this parses the chain
    back into the structured fields.
    """
    # Prefer the spaced arrow; fall back to the bare one for robustness.
    sep = " -> " if " -> " in h.statement else "->"
    segments = [segment.strip() for segment in h.statement.split(sep)]

    if len(segments) == 4 and all(segments):
        drug, target, pathway, effect = segments
    elif len(segments) > 4 and all(segments[:4]):
        # More than four segments: keep the first three fields and fold the
        # remainder back into the effect text.
        drug, target, pathway = segments[:3]
        effect = sep.join(segments[3:])
        logger.debug(
            "Hypothesis has extra parts, joined into effect",
            hypothesis_id=h.id,
            parts_count=len(segments),
        )
    else:
        # Unparseable statement: fall back to placeholders so report
        # generation can still proceed.
        logger.warning(
            "Failed to parse hypothesis statement format",
            hypothesis_id=h.id,
            statement=h.statement[:100],
            parts_count=len(segments),
        )
        drug = target = pathway = "Unknown"
        effect = h.statement.strip() if h.statement else "Unknown effect"

    return MechanismHypothesis(
        drug=drug,
        target=target,
        pathway=pathway,
        effect=effect,
        confidence=h.confidence,
        supporting_evidence=h.supporting_evidence_ids,
        contradicting_evidence=h.contradicting_evidence_ids,
    )
|
|
|
|
def _results_to_evidence(results: list[dict[str, Any]]) -> list[Evidence]:
    """Convert search_similar results to Evidence objects.

    Shared helper so judge_node and synthesize_node construct evidence
    identically.
    """

    def _to_evidence(record: dict[str, Any]) -> Evidence:
        # Metadata may be absent entirely; default every field defensively.
        meta = record.get("metadata", {})
        raw_authors = meta.get("authors", "")
        authors = [name.strip() for name in raw_authors.split(",")] if raw_authors else []
        return Evidence(
            content=record.get("content", ""),
            citation=Citation(
                url=record.get("id", ""),
                title=meta.get("title", "Unknown"),
                source=meta.get("source", "Unknown"),
                date=meta.get("date", ""),
                authors=authors,
            ),
        )

    return [_to_evidence(record) for record in results]
|
|
|
|
| |
class SupervisorDecision(BaseModel):
    """The decision made by the supervisor.

    Parsed from the supervisor LLM's structured output via
    ``PydanticOutputParser`` in ``supervisor_node``.
    """

    # Which graph node to route to next; "finish" ends the run.
    next_step: Literal["search", "judge", "resolve", "synthesize", "finish"] = Field(
        description="The next step to take in the research process."
    )
    # Free-text justification; surfaced back into the message history.
    reasoning: str = Field(description="Reasoning for this decision.")
|
|
|
|
| |
|
|
|
|
async def search_node(
    state: ResearchState, embedding_service: EmbeddingServiceProtocol | None = None
) -> dict[str, Any]:
    """Execute search across all sources."""
    query = state["query"]
    logger.info("search_node: executing search", query=query)

    # Fan the query out to every literature source through one handler.
    search_tools: list[SearchTool] = [PubMedTool(), ClinicalTrialsTool(), EuropePMCTool()]
    result = await SearchHandler(tools=search_tools).execute(query)

    new_ids: list[str] = []
    if embedding_service and result.evidence:
        # Drop near-duplicates before counting or tracking new evidence.
        unique_evidence = await embedding_service.deduplicate(result.evidence)
        new_ids = [ev.citation.url for ev in unique_evidence]
        new_evidence_count = len(unique_evidence)
    else:
        # NOTE(review): without an embedding service the evidence is counted
        # but its IDs are not tracked in state — confirm this is intended.
        new_evidence_count = len(result.evidence)

    message = (
        f"Search completed. Found {result.total_found} total, "
        f"{new_evidence_count} unique new papers."
    )
    if result.errors:
        message += f" Errors: {'; '.join(result.errors)}"

    return {
        "evidence_ids": new_ids,
        "messages": [AIMessage(content=message)],
    }
|
|
|
|
async def judge_node(
    state: ResearchState, embedding_service: EmbeddingServiceProtocol | None = None
) -> dict[str, Any]:
    """Evaluate evidence and update hypothesis confidence."""
    logger.info("judge_node: evaluating evidence")

    # Pull the most relevant stored evidence if a vector store is wired in.
    evidence_context: list[Evidence] = []
    if embedding_service:
        hits = await embedding_service.search_similar(state["query"], n_results=20)
        evidence_context = _results_to_evidence(hits)

    agent = Agent(
        model=get_model(),
        output_type=HypothesisAssessment,
        system_prompt=HYPOTHESIS_SYSTEM_PROMPT,
    )
    prompt = await format_hypothesis_prompt(
        query=state["query"], evidence=evidence_context, embeddings=embedding_service
    )

    try:
        run_result = await agent.run(prompt)
        assessment = run_result.output

        # Translate the structured assessment into state-level hypotheses.
        new_hypotheses = [
            Hypothesis(
                id=h.drug,
                statement=f"{h.drug} -> {h.target} -> {h.pathway} -> {h.effect}",
                status="proposed",
                confidence=h.confidence,
                supporting_evidence_ids=[],
                contradicting_evidence_ids=[],
            )
            for h in assessment.hypotheses
        ]

        return {
            "hypotheses": new_hypotheses,
            "messages": [AIMessage(content=f"Judge: Generated {len(new_hypotheses)} hypotheses.")],
            "next_step": "resolve",
        }
    except Exception as e:
        # On any failure, send the graph back to search rather than halting.
        logger.error("judge_node failed", error=str(e))
        return {"messages": [AIMessage(content=f"Judge Error: {e!s}")], "next_step": "search"}
|
|
|
|
async def resolve_node(
    state: ResearchState, embedding_service: EmbeddingServiceProtocol | None = None
) -> dict[str, Any]:
    """Handle open conflicts."""
    # A hypothesis above this confidence threshold counts as "high confidence".
    strong = [h for h in state["hypotheses"] if h.confidence > 0.8]

    if strong:
        summary = (
            f"Resolver: Found {len(strong)} high confidence hypotheses. "
            "Conflicts resolved."
        )
    else:
        summary = "Resolver: No high confidence hypotheses yet."

    return {"messages": [AIMessage(content=summary)]}
|
|
|
|
async def synthesize_node(
    state: ResearchState, embedding_service: EmbeddingServiceProtocol | None = None
) -> dict[str, Any]:
    """Generate the final report.

    Retrieves the most relevant stored evidence, converts state hypotheses
    into structured mechanism hypotheses, asks the report agent to
    synthesize a ``ResearchReport``, and validates its references against
    the retrieved evidence before returning.
    """
    logger.info("synthesize_node: generating report")

    evidence_context: list[Evidence] = []
    if embedding_service:
        scored_points = await embedding_service.search_similar(state["query"], n_results=50)
        evidence_context = _results_to_evidence(scored_points)

    agent = Agent(
        model=get_model(),
        output_type=ResearchReport,
        system_prompt=REPORT_SYSTEM_PROMPT,
    )

    # Report generation expects structured mechanism hypotheses, not the
    # flat statement form kept in state.
    mechanism_hypotheses = [_convert_hypothesis_to_mechanism(h) for h in state["hypotheses"]]

    # Set comprehension (ruff C401) instead of set(generator); sorted() so
    # the source list in report metadata is deterministic across runs.
    sources = sorted({e.citation.source for e in evidence_context})

    prompt = await format_report_prompt(
        query=state["query"],
        evidence=evidence_context,
        hypotheses=mechanism_hypotheses,
        assessment={},
        metadata={"sources": sources},
        embeddings=embedding_service,
    )

    try:
        result = await agent.run(prompt)
        # Strip/repair citations that don't match the retrieved evidence.
        report = validate_references(result.output, evidence_context)
        return {"messages": [AIMessage(content=report.to_markdown())], "next_step": "finish"}
    except Exception as e:
        logger.error("synthesize_node failed", error=str(e))
        return {"messages": [AIMessage(content=f"Synthesis Error: {e!s}")], "next_step": "finish"}
|
|
|
|
async def supervisor_node(state: ResearchState, llm: BaseChatModel | None = None) -> dict[str, Any]:
    """Route to next node based on state using robust Pydantic parsing."""
    iteration = state["iteration_count"]

    # Hard stop: force synthesis once the iteration budget is exhausted.
    if iteration >= state["max_iterations"]:
        return {"next_step": "synthesize", "iteration_count": iteration}

    # No LLM available: degrade to always searching.
    if llm is None:
        return {"next_step": "search", "iteration_count": iteration + 1}

    parser = PydanticOutputParser(pydantic_object=SupervisorDecision)
    prompt = ChatPromptTemplate.from_messages(
        [
            (
                "system",
                "You are the Research Supervisor. Manage the workflow.\n\n"
                "State Summary:\n"
                "- Query: {query}\n"
                "- Hypotheses: {hypo_count}\n"
                "- Conflicts: {conflict_count}\n"
                "- Iteration: {iteration}/{max_iter}\n\n"
                "Decide the next step based on this logic:\n"
                "1. If there are open conflicts -> 'resolve'\n"
                "2. If hypotheses are unverified or few -> 'search'\n"
                "3. If new evidence needs evaluation -> 'judge'\n"
                "4. If hypotheses are confirmed -> 'synthesize'\n\n"
                "{format_instructions}",
            ),
            ("user", "What is the next step?"),
        ]
    )
    chain = prompt | llm | parser

    try:
        decision: SupervisorDecision = await chain.ainvoke(
            {
                "query": state["query"],
                "hypo_count": len(state["hypotheses"]),
                "conflict_count": len([c for c in state["conflicts"] if c.status == "open"]),
                "iteration": iteration,
                "max_iter": state["max_iterations"],
                "format_instructions": parser.get_format_instructions(),
            }
        )
        return {
            "next_step": decision.next_step,
            "iteration_count": iteration + 1,
            "messages": [AIMessage(content=f"Supervisor: {decision.reasoning}")],
        }
    except Exception as e:
        # Any parsing/LLM failure ends the loop gracefully via synthesis.
        return {
            "next_step": "synthesize",
            "iteration_count": iteration + 1,
            "messages": [AIMessage(content=f"Supervisor Error: {e!s}. Proceeding to synthesis.")],
        }
|
|