File size: 2,903 Bytes
60ce079
 
 
 
 
 
 
 
 
 
e68a66e
60ce079
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
from langgraph.graph import StateGraph, END, START
from langchain_core.tools import tool
from agent_state import AgentState
from workflow_nodes import (
    expand_query, retrieve_context, craft_response, 
    score_groundedness, refine_response, check_precision, 
    refine_query, max_iterations_reached
)
from workflow_conditions import should_continue_groundedness, should_continue_precision


def create_workflow() -> StateGraph:
    """Assemble the (uncompiled) state graph behind the AI nutrition agent.

    Flow: START -> expand_query -> retrieve_context -> craft_response ->
    score_groundedness.

    From score_groundedness, ``should_continue_groundedness`` routes to
    check_precision, to refine_response (which loops back to craft_response),
    or to max_iterations_reached.

    From check_precision, ``should_continue_precision`` routes to END on
    "pass", to refine_query (which loops back to expand_query), or to
    max_iterations_reached. That node always leads to END.

    Returns:
        StateGraph: graph over ``AgentState``; call ``.compile()`` before use.
    """
    graph = StateGraph(AgentState)

    # Register every processing node under the name the edges refer to.
    node_fns = {
        "expand_query": expand_query,
        "retrieve_context": retrieve_context,
        "craft_response": craft_response,
        "score_groundedness": score_groundedness,
        "refine_response": refine_response,
        "check_precision": check_precision,
        "refine_query": refine_query,
        "max_iterations_reached": max_iterations_reached,
    }
    for node_name, node_fn in node_fns.items():
        graph.add_node(node_name, node_fn)

    # Unconditional path from the entry point up to the first scoring step.
    pipeline_edges = [
        (START, "expand_query"),
        ("expand_query", "retrieve_context"),
        ("retrieve_context", "craft_response"),
        ("craft_response", "score_groundedness"),
    ]
    for src, dst in pipeline_edges:
        graph.add_edge(src, dst)

    # Groundedness gate: the router's label is used verbatim as the
    # destination node name.
    grounded_routes = {
        label: label
        for label in ("check_precision", "refine_response", "max_iterations_reached")
    }
    graph.add_conditional_edges(
        "score_groundedness", should_continue_groundedness, grounded_routes
    )
    graph.add_edge("refine_response", "craft_response")

    # Precision gate: "pass" terminates the run; otherwise refine the query
    # and restart from expansion, or stop at the iteration cap.
    precision_routes = {
        "pass": END,
        "refine_query": "refine_query",
        "max_iterations_reached": "max_iterations_reached",
    }
    graph.add_conditional_edges(
        "check_precision", should_continue_precision, precision_routes
    )
    graph.add_edge("refine_query", "expand_query")
    graph.add_edge("max_iterations_reached", END)

    return graph

# Compile the graph once at import time; agentic_rag() reuses this single
# compiled app for every call.
WORKFLOW_APP = create_workflow().compile()

# Refinement budget handed to the workflow as ``loop_max_iter``.
# NOTE(review): presumably compared against groundedness_loop_count /
# precision_loop_count by the routers in workflow_conditions. Confirm there.
DEFAULT_LOOP_MAX_ITER = 2


@tool
def agentic_rag(query: str) -> dict:
    """
    Answer a nutrition question using the self-correcting RAG workflow.

    The query is expanded, relevant context is retrieved, and a response is
    drafted. The workflow may then refine the response after a groundedness
    check, or refine the query after a precision check, before finishing.
    Each call is independent: no conversation history is used.

    Args:
        query (str): The current user query.

    Returns:
        Dict[str, Any]: The final workflow state. This includes the generated
        "response", the retrieved "context", and the "groundedness_score" and
        "precision_score" values.
    """
    # Initialise every state field explicitly, so each call starts from
    # empty text, zeroed scores and zeroed loop counters.
    inputs = {
        "query": query,
        "expanded_query": "",
        "context": [],
        "response": "",
        "precision_score": 0.0,
        "groundedness_score": 0.0,
        "groundedness_loop_count": 0,
        "precision_loop_count": 0,
        "feedback": "",
        "query_feedback": "",
        "loop_max_iter": DEFAULT_LOOP_MAX_ITER,
    }

    return WORKFLOW_APP.invoke(inputs)