Spaces:
Running
Running
| """Graph node implementations for DeepBoner research.""" | |
| from typing import Any, Literal | |
| import structlog | |
| from langchain_core.language_models.chat_models import BaseChatModel | |
| from langchain_core.messages import AIMessage | |
| from langchain_core.output_parsers import PydanticOutputParser | |
| from langchain_core.prompts import ChatPromptTemplate | |
| from pydantic import BaseModel, Field | |
| from pydantic_ai import Agent | |
| from src.agent_factory.judges import get_model | |
| from src.agents.graph.state import Hypothesis, ResearchState | |
| from src.prompts.hypothesis import SYSTEM_PROMPT as HYPOTHESIS_SYSTEM_PROMPT | |
| from src.prompts.hypothesis import format_hypothesis_prompt | |
| from src.prompts.report import SYSTEM_PROMPT as REPORT_SYSTEM_PROMPT | |
| from src.prompts.report import format_report_prompt | |
| from src.services.embedding_protocol import EmbeddingServiceProtocol | |
| from src.tools.base import SearchTool | |
| from src.tools.clinicaltrials import ClinicalTrialsTool | |
| from src.tools.europepmc import EuropePMCTool | |
| from src.tools.pubmed import PubMedTool | |
| from src.tools.search_handler import SearchHandler | |
| from src.utils.citation_validator import validate_references | |
| from src.utils.models import ( | |
| Citation, | |
| Evidence, | |
| HypothesisAssessment, | |
| MechanismHypothesis, | |
| ResearchReport, | |
| ) | |
# Module-level structlog logger used by the helpers and graph nodes in this file.
logger = structlog.get_logger()
def _convert_hypothesis_to_mechanism(h: Hypothesis) -> MechanismHypothesis:
    """Rebuild a structured MechanismHypothesis from a state-level Hypothesis.

    The statement is expected to look like ``drug -> target -> pathway -> effect``.
    It is split on the spaced arrow (`` -> ``) when that occurs anywhere in the
    statement, otherwise on the compact arrow (``->``). Segments past the fourth
    are re-joined with the same arrow and folded into the effect. When fewer than
    four segments are found, or one of the first four is blank, drug, target and
    pathway fall back to ``"Unknown"`` and the stripped statement becomes the
    effect (``"Unknown effect"`` for an empty statement).

    Confidence and the supporting/contradicting evidence ids are copied across
    unchanged.
    """
    statement = h.statement

    # The spaced arrow wins whenever it is present, so a compact "->" embedded in
    # a single field does not split that field.
    arrow = " -> " if " -> " in statement else "->"
    segments = [segment.strip() for segment in statement.split(arrow)]
    segment_count = len(segments)

    if segment_count >= 4 and all(segments[:4]):
        drug, target, pathway, *effect_segments = segments
        # With exactly four segments this join is a no-op on the single element.
        effect = arrow.join(effect_segments)
        if segment_count > 4:
            logger.debug(
                "Hypothesis has extra parts, joined into effect",
                hypothesis_id=h.id,
                parts_count=segment_count,
            )
    else:
        logger.warning(
            "Failed to parse hypothesis statement format",
            hypothesis_id=h.id,
            statement=statement[:100],  # keep the log entry short
            parts_count=segment_count,
        )
        drug = target = pathway = "Unknown"
        effect = statement.strip() if statement else "Unknown effect"

    return MechanismHypothesis(
        drug=drug,
        target=target,
        pathway=pathway,
        effect=effect,
        confidence=h.confidence,
        supporting_evidence=h.supporting_evidence_ids,
        contradicting_evidence=h.contradicting_evidence_ids,
    )
| def _results_to_evidence(results: list[dict[str, Any]]) -> list[Evidence]: | |
| """Convert search_similar results to Evidence objects. | |
| Extracted helper to avoid code duplication between judge_node and synthesize_node. | |
| """ | |
| evidence_list = [] | |
| for r in results: | |
| meta = r.get("metadata", {}) | |
| authors_str = meta.get("authors", "") | |
| author_list = [a.strip() for a in authors_str.split(",")] if authors_str else [] | |
| evidence_list.append( | |
| Evidence( | |
| content=r.get("content", ""), | |
| citation=Citation( | |
| url=r.get("id", ""), | |
| title=meta.get("title", "Unknown"), | |
| source=meta.get("source", "Unknown"), | |
| date=meta.get("date", ""), | |
| authors=author_list, | |
| ), | |
| ) | |
| ) | |
| return evidence_list | |
| # --- Supervisor Output Schema --- | |
# Structured output that supervisor_node asks the LLM to produce; it is handed to
# PydanticOutputParser as the target model.
# NOTE(review): the class docstring and the Field descriptions presumably end up in
# parser.get_format_instructions() and therefore in the prompt — confirm before
# rewording any of them.
class SupervisorDecision(BaseModel):
    """The decision made by the supervisor."""

    # Routing target chosen by the supervisor; restricted to these five literals.
    next_step: Literal["search", "judge", "resolve", "synthesize", "finish"] = Field(
        description="The next step to take in the research process."
    )
    # Free-text justification; supervisor_node echoes it back in an AIMessage.
    reasoning: str = Field(description="Reasoning for this decision.")
| # --- Nodes --- | |
async def search_node(
    state: ResearchState, embedding_service: EmbeddingServiceProtocol | None = None
) -> dict[str, Any]:
    """Run the query against PubMed, ClinicalTrials and Europe PMC.

    When an embedding service is supplied and the search returned evidence, the
    evidence is passed through ``embedding_service.deduplicate`` and the citation
    URLs of the surviving items are returned as ``evidence_ids``. Otherwise no
    ids are returned and the reported count is simply the number of hits.

    Returns:
        A state update with ``evidence_ids`` and a single summary ``AIMessage``
        (including any errors reported by the search handler).
    """
    query = state["query"]
    logger.info("search_node: executing search", query=query)

    sources: list[SearchTool] = [PubMedTool(), ClinicalTrialsTool(), EuropePMCTool()]
    search_result = await SearchHandler(tools=sources).execute(query)

    stored_ids = []
    if embedding_service and search_result.evidence:
        # deduplicate() also persists the evidence, so only the ids are tracked here.
        deduped = await embedding_service.deduplicate(search_result.evidence)
        stored_ids = [item.citation.url for item in deduped]
        unique_count = len(deduped)
    else:
        unique_count = len(search_result.evidence)

    summary = (
        f"Search completed. Found {search_result.total_found} total, "
        f"{unique_count} unique new papers."
    )
    if search_result.errors:
        summary = f"{summary} Errors: {'; '.join(search_result.errors)}"

    return {
        "evidence_ids": stored_ids,
        "messages": [AIMessage(content=summary)],
    }
async def judge_node(
    state: ResearchState, embedding_service: EmbeddingServiceProtocol | None = None
) -> dict[str, Any]:
    """Ask the hypothesis agent to propose mechanisms for the current query.

    Up to 20 similar evidence items are pulled from the embedding service (when
    one is supplied) and passed to the hypothesis prompt. Each hypothesis in the
    agent's assessment is stored as a ``Hypothesis`` with status ``"proposed"``
    and a ``drug -> target -> pathway -> effect`` statement.

    Returns:
        On success: the new hypotheses, a summary message and
        ``next_step="resolve"``. If the agent run or the conversion raises: an
        error message and ``next_step="search"``.
    """
    logger.info("judge_node: evaluating evidence")

    if embedding_service:
        hits = await embedding_service.search_similar(state["query"], n_results=20)
        evidence_context: list[Evidence] = _results_to_evidence(hits)
    else:
        evidence_context = []

    agent = Agent(
        model=get_model(),
        output_type=HypothesisAssessment,
        system_prompt=HYPOTHESIS_SYSTEM_PROMPT,
    )
    prompt = await format_hypothesis_prompt(
        query=state["query"], evidence=evidence_context, embeddings=embedding_service
    )

    try:
        run_result = await agent.run(prompt)
        # NOTE(review): the id is the drug name, so two hypotheses about the same
        # drug share an id — confirm that downstream consumers tolerate this.
        proposed = [
            Hypothesis(
                id=mech.drug,
                statement=f"{mech.drug} -> {mech.target} -> {mech.pathway} -> {mech.effect}",
                status="proposed",
                confidence=mech.confidence,
                supporting_evidence_ids=[],
                contradicting_evidence_ids=[],
            )
            for mech in run_result.output.hypotheses
        ]
        summary = AIMessage(content=f"Judge: Generated {len(proposed)} hypotheses.")
        return {
            "hypotheses": proposed,
            "messages": [summary],
            "next_step": "resolve",
        }
    except Exception as exc:
        logger.error("judge_node failed", error=str(exc))
        return {"messages": [AIMessage(content=f"Judge Error: {exc!s}")], "next_step": "search"}
async def resolve_node(
    state: ResearchState, embedding_service: EmbeddingServiceProtocol | None = None
) -> dict[str, Any]:
    """Report whether any hypothesis has confidence above 0.8.

    The node only emits a status message; it does not modify hypotheses or
    conflicts, and ``embedding_service`` is accepted but not used.

    Returns:
        A state update containing a single resolver ``AIMessage``.
    """
    # Hypotheses are Pydantic models, hence attribute access.
    confident = [hyp for hyp in state["hypotheses"] if hyp.confidence > 0.8]

    if confident:
        note = (
            f"Resolver: Found {len(confident)} high confidence hypotheses. "
            "Conflicts resolved."
        )
    else:
        note = "Resolver: No high confidence hypotheses yet."

    return {"messages": [AIMessage(content=note)]}
async def synthesize_node(
    state: ResearchState, embedding_service: EmbeddingServiceProtocol | None = None
) -> dict[str, Any]:
    """Generate the final research report.

    Up to 50 similar evidence items are pulled from the embedding service (when
    one is supplied), the state hypotheses are converted to
    ``MechanismHypothesis`` objects, and the report agent is run on the
    formatted prompt. The resulting report is passed through
    ``validate_references`` before being rendered to markdown.

    Returns:
        A state update with one ``AIMessage`` (the markdown report, or a
        ``Synthesis Error`` message if the agent run or validation raises) and
        ``next_step="finish"``.
    """
    logger.info("synthesize_node: generating report")

    evidence_context: list[Evidence] = []
    if embedding_service:
        scored_points = await embedding_service.search_similar(state["query"], n_results=50)
        evidence_context = _results_to_evidence(scored_points)

    agent = Agent(
        model=get_model(),
        output_type=ResearchReport,
        system_prompt=REPORT_SYSTEM_PROMPT,
    )

    # Convert state hypotheses to MechanismHypothesis for report generation.
    mechanism_hypotheses = [_convert_hypothesis_to_mechanism(h) for h in state["hypotheses"]]

    # De-duplicate sources while keeping first-seen order. list(set(...)) made the
    # order depend on string hash randomisation, so the prompt differed between
    # runs on identical evidence.
    sources = list(dict.fromkeys(e.citation.source for e in evidence_context))

    prompt = await format_report_prompt(
        query=state["query"],
        evidence=evidence_context,
        hypotheses=mechanism_hypotheses,
        assessment={},
        metadata={"sources": sources},
        embeddings=embedding_service,
    )

    try:
        result = await agent.run(prompt)
        report = validate_references(result.output, evidence_context)
        return {"messages": [AIMessage(content=report.to_markdown())], "next_step": "finish"}
    except Exception as e:
        logger.error("synthesize_node failed", error=str(e))
        return {"messages": [AIMessage(content=f"Synthesis Error: {e!s}")], "next_step": "finish"}
async def supervisor_node(state: ResearchState, llm: BaseChatModel | None = None) -> dict[str, Any]:
    """Pick the next graph node, using the LLM when one is available.

    Routing rules, in order:
      1. Iteration budget exhausted -> ``"synthesize"`` (counter left unchanged).
      2. No ``llm`` supplied -> ``"search"`` (counter incremented).
      3. Otherwise the LLM is prompted with a state summary and its reply is
         parsed into a ``SupervisorDecision``.

    Args:
        state: Current research state; reads ``query``, ``hypotheses``,
            ``conflicts``, ``iteration_count`` and ``max_iterations``.
        llm: Chat model used for routing, or ``None`` for the fixed fallback.

    Returns:
        A state update with ``next_step`` and ``iteration_count``, plus a
        supervisor ``AIMessage`` on the LLM path. If the LLM call or parsing
        raises, the error is logged and the graph is routed to ``"synthesize"``.
    """
    iteration = state["iteration_count"]

    if iteration >= state["max_iterations"]:
        return {"next_step": "synthesize", "iteration_count": iteration}

    if llm is None:
        return {"next_step": "search", "iteration_count": iteration + 1}

    parser = PydanticOutputParser(pydantic_object=SupervisorDecision)
    prompt = ChatPromptTemplate.from_messages(
        [
            (
                "system",
                "You are the Research Supervisor. Manage the workflow.\n\n"
                "State Summary:\n"
                "- Query: {query}\n"
                "- Hypotheses: {hypo_count}\n"
                "- Conflicts: {conflict_count}\n"
                "- Iteration: {iteration}/{max_iter}\n\n"
                "Decide the next step based on this logic:\n"
                "1. If there are open conflicts -> 'resolve'\n"
                "2. If hypotheses are unverified or few -> 'search'\n"
                "3. If new evidence needs evaluation -> 'judge'\n"
                "4. If hypotheses are confirmed -> 'synthesize'\n\n"
                "{format_instructions}",
            ),
            ("user", "What is the next step?"),
        ]
    )
    chain = prompt | llm | parser

    try:
        # state["conflicts"] holds Pydantic models, hence attribute access. The
        # count stays inside the try so a malformed state also falls back to
        # synthesis, as before.
        open_conflicts = sum(1 for c in state["conflicts"] if c.status == "open")
        decision: SupervisorDecision = await chain.ainvoke(
            {
                "query": state["query"],
                "hypo_count": len(state["hypotheses"]),
                "conflict_count": open_conflicts,
                "iteration": iteration,
                "max_iter": state["max_iterations"],
                "format_instructions": parser.get_format_instructions(),
            }
        )
        return {
            "next_step": decision.next_step,
            "iteration_count": iteration + 1,
            "messages": [AIMessage(content=f"Supervisor: {decision.reasoning}")],
        }
    except Exception as e:
        # Log the failure like judge_node and synthesize_node do; previously it
        # was only visible in the chat message.
        logger.error("supervisor_node failed", error=str(e))
        return {
            "next_step": "synthesize",
            "iteration_count": iteration + 1,
            "messages": [AIMessage(content=f"Supervisor Error: {e!s}. Proceeding to synthesis.")],
        }