File size: 5,297 Bytes
1e732dd
 
 
 
 
 
 
 
 
696f787
1e732dd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
696f787
1e732dd
 
696f787
1e732dd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
696f787
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
"""
MediGuard AI — Agentic RAG Orchestrator

LangGraph StateGraph that wires all nodes into the guardrail → retrieve → grade → generate pipeline.
"""

from __future__ import annotations

import logging
from functools import partial
from typing import Any

from langgraph.graph import END, StateGraph

from src.services.agents.context import AgenticContext
from src.services.agents.nodes.generate_answer_node import generate_answer_node
from src.services.agents.nodes.grade_documents_node import grade_documents_node
from src.services.agents.nodes.guardrail_node import guardrail_node
from src.services.agents.nodes.out_of_scope_node import out_of_scope_node
from src.services.agents.nodes.retrieve_node import retrieve_node
from src.services.agents.nodes.rewrite_query_node import rewrite_query_node
from src.services.agents.state import AgenticRAGState

logger = logging.getLogger(__name__)

# ---------------------------------------------------------------------------
# Edge routing helpers
# ---------------------------------------------------------------------------


def _route_after_guardrail(state: dict) -> str:
    """Decide path after guardrail evaluation."""
    if state.get("routing_decision") == "analyze":
        # Biomarker analysis pathway — goes straight to retrieve
        return "retrieve"
    if state.get("is_in_scope"):
        return "retrieve"
    return "out_of_scope"


def _route_after_grading(state: dict) -> str:
    """Decide whether to rewrite query or proceed to generation."""
    if state.get("needs_rewrite"):
        return "rewrite_query"
    if not state.get("relevant_documents"):
        return "generate_answer"  # will produce a "no evidence found" answer
    return "generate_answer"


# ---------------------------------------------------------------------------
# Graph builder
# ---------------------------------------------------------------------------


def build_agentic_rag_graph(context: AgenticContext) -> Any:
    """Construct the compiled LangGraph for the agentic RAG pipeline.

    Parameters
    ----------
    context:
        Runtime dependencies (LLM, OpenSearch, embeddings, cache, tracer).

    Returns
    -------
    Compiled LangGraph graph ready for ``.invoke()`` / ``.stream()``.
    """
    graph = StateGraph(AgenticRAGState)

    # Register every node with the shared runtime context bound via
    # functools.partial.
    node_fns = {
        "guardrail": guardrail_node,
        "retrieve": retrieve_node,
        "grade_documents": grade_documents_node,
        "rewrite_query": rewrite_query_node,
        "generate_answer": generate_answer_node,
        "out_of_scope": out_of_scope_node,
    }
    for node_name, node_fn in node_fns.items():
        graph.add_node(node_name, partial(node_fn, context=context))

    # The guardrail is always evaluated first.
    graph.set_entry_point("guardrail")

    # Guardrail verdict: in-scope (or biomarker-analysis) queries flow to
    # retrieval; everything else gets the out-of-scope response.
    graph.add_conditional_edges(
        "guardrail",
        _route_after_guardrail,
        {"retrieve": "retrieve", "out_of_scope": "out_of_scope"},
    )

    graph.add_edge("retrieve", "grade_documents")

    # Grading either triggers a query rewrite or proceeds to generation.
    graph.add_conditional_edges(
        "grade_documents",
        _route_after_grading,
        {"rewrite_query": "rewrite_query", "generate_answer": "generate_answer"},
    )

    # Rewritten queries loop back through retrieval.
    graph.add_edge("rewrite_query", "retrieve")

    # Both answer paths terminate the graph.
    graph.add_edge("generate_answer", END)
    graph.add_edge("out_of_scope", END)

    return graph.compile()


# ---------------------------------------------------------------------------
# Public API
# ---------------------------------------------------------------------------


class AgenticRAGService:
    """High-level wrapper around the compiled agentic RAG graph.

    Builds the graph once at construction time and exposes a single
    ``ask`` entry point that runs the full pipeline, starting a trace
    span when a tracer is configured and returning a safe fallback
    answer when the pipeline raises.
    """

    def __init__(self, context: AgenticContext) -> None:
        self._context = context
        self._graph = build_agentic_rag_graph(context)

    def ask(
        self,
        query: str,
        biomarkers: dict | None = None,
        patient_context: str = "",
    ) -> dict:
        """Run the full agentic RAG pipeline and return the final state.

        Parameters
        ----------
        query:
            The user's natural-language question.
        biomarkers:
            Optional biomarker values for the analysis pathway.
        patient_context:
            Optional free-text patient background.

        Returns
        -------
        The final graph state on success; on failure, a copy of the
        initial state with a fallback ``final_answer`` and the error
        message recorded in ``errors``.
        """
        initial_state: dict[str, Any] = {
            "query": query,
            "biomarkers": biomarkers,
            "patient_context": patient_context,
            "errors": [],
        }

        try:
            # Start a trace span for observability when a tracer is
            # configured.  The original bound the span to an unused local
            # (`trace_obj`); the call is kept for its side effect only.
            if self._context.tracer:
                self._context.tracer.trace(
                    name="agentic_rag_ask",
                    metadata={"query": query},
                )
            return self._graph.invoke(initial_state)
        except Exception as exc:
            # Top-level boundary: log with full traceback
            # (logger.exception, not logger.error) and return a safe
            # fallback instead of propagating internal errors to callers.
            logger.exception("Agentic RAG pipeline failed: %s", exc)
            return {
                **initial_state,
                "final_answer": (
                    "I apologize, but I'm temporarily unable to process your request. "
                    "Please consult a healthcare professional."
                ),
                "errors": [str(exc)],
            }
        finally:
            # Flush buffered trace events regardless of outcome.
            if self._context.tracer:
                self._context.tracer.flush()