Spaces:
Sleeping
Sleeping
| """ | |
| MediGuard AI — Agentic RAG Orchestrator | |
| LangGraph StateGraph that wires all nodes into the guardrail → retrieve → grade → generate pipeline. | |
| """ | |
| from __future__ import annotations | |
| import logging | |
| from functools import partial | |
| from typing import Any | |
| from langgraph.graph import END, StateGraph | |
| from src.services.agents.context import AgenticContext | |
| from src.services.agents.nodes.generate_answer_node import generate_answer_node | |
| from src.services.agents.nodes.grade_documents_node import grade_documents_node | |
| from src.services.agents.nodes.guardrail_node import guardrail_node | |
| from src.services.agents.nodes.out_of_scope_node import out_of_scope_node | |
| from src.services.agents.nodes.retrieve_node import retrieve_node | |
| from src.services.agents.nodes.rewrite_query_node import rewrite_query_node | |
| from src.services.agents.state import AgenticRAGState | |
| logger = logging.getLogger(__name__) | |
| # --------------------------------------------------------------------------- | |
| # Edge routing helpers | |
| # --------------------------------------------------------------------------- | |
| def _route_after_guardrail(state: dict) -> str: | |
| """Decide path after guardrail evaluation.""" | |
| if state.get("routing_decision") == "analyze": | |
| # Biomarker analysis pathway — goes straight to retrieve | |
| return "retrieve" | |
| if state.get("is_in_scope"): | |
| return "retrieve" | |
| return "out_of_scope" | |
| def _route_after_grading(state: dict) -> str: | |
| """Decide whether to rewrite query or proceed to generation.""" | |
| if state.get("needs_rewrite"): | |
| return "rewrite_query" | |
| if not state.get("relevant_documents"): | |
| return "generate_answer" # will produce a "no evidence found" answer | |
| return "generate_answer" | |
| # --------------------------------------------------------------------------- | |
| # Graph builder | |
| # --------------------------------------------------------------------------- | |
| def build_agentic_rag_graph(context: AgenticContext) -> Any: | |
| """Construct the compiled LangGraph for the agentic RAG pipeline. | |
| Parameters | |
| ---------- | |
| context: | |
| Runtime dependencies (LLM, OpenSearch, embeddings, cache, tracer). | |
| Returns | |
| ------- | |
| Compiled LangGraph graph ready for ``.invoke()`` / ``.stream()``. | |
| """ | |
| workflow = StateGraph(AgenticRAGState) | |
| # Bind context to every node via functools.partial | |
| workflow.add_node("guardrail", partial(guardrail_node, context=context)) | |
| workflow.add_node("retrieve", partial(retrieve_node, context=context)) | |
| workflow.add_node("grade_documents", partial(grade_documents_node, context=context)) | |
| workflow.add_node("rewrite_query", partial(rewrite_query_node, context=context)) | |
| workflow.add_node("generate_answer", partial(generate_answer_node, context=context)) | |
| workflow.add_node("out_of_scope", partial(out_of_scope_node, context=context)) | |
| # Entry point | |
| workflow.set_entry_point("guardrail") | |
| # Conditional edges | |
| workflow.add_conditional_edges( | |
| "guardrail", | |
| _route_after_guardrail, | |
| { | |
| "retrieve": "retrieve", | |
| "out_of_scope": "out_of_scope", | |
| }, | |
| ) | |
| workflow.add_edge("retrieve", "grade_documents") | |
| workflow.add_conditional_edges( | |
| "grade_documents", | |
| _route_after_grading, | |
| { | |
| "rewrite_query": "rewrite_query", | |
| "generate_answer": "generate_answer", | |
| }, | |
| ) | |
| # After rewrite, loop back to retrieve | |
| workflow.add_edge("rewrite_query", "retrieve") | |
| # Terminal edges | |
| workflow.add_edge("generate_answer", END) | |
| workflow.add_edge("out_of_scope", END) | |
| return workflow.compile() | |
| # --------------------------------------------------------------------------- | |
| # Public API | |
| # --------------------------------------------------------------------------- | |
| class AgenticRAGService: | |
| """High-level wrapper around the compiled RAG graph.""" | |
| def __init__(self, context: AgenticContext) -> None: | |
| self._context = context | |
| self._graph = build_agentic_rag_graph(context) | |
| def ask( | |
| self, | |
| query: str, | |
| biomarkers: dict | None = None, | |
| patient_context: str = "", | |
| ) -> dict: | |
| """Run the full agentic RAG pipeline and return the final state.""" | |
| initial_state: dict[str, Any] = { | |
| "query": query, | |
| "biomarkers": biomarkers, | |
| "patient_context": patient_context, | |
| "errors": [], | |
| } | |
| trace_obj = None | |
| try: | |
| if self._context.tracer: | |
| trace_obj = self._context.tracer.trace( | |
| name="agentic_rag_ask", | |
| metadata={"query": query}, | |
| ) | |
| result = self._graph.invoke(initial_state) | |
| return result | |
| except Exception as exc: | |
| logger.error("Agentic RAG pipeline failed: %s", exc) | |
| return { | |
| **initial_state, | |
| "final_answer": ( | |
| "I apologize, but I'm temporarily unable to process your request. " | |
| "Please consult a healthcare professional." | |
| ), | |
| "errors": [str(exc)], | |
| } | |
| finally: | |
| if self._context.tracer: | |
| self._context.tracer.flush() | |