# modular-rag-bot — core/workflow.py
from langgraph.graph import StateGraph, END, START
from langchain_core.tools import tool
from agent_state import AgentState
from workflow_nodes import (
expand_query, retrieve_context, craft_response,
score_groundedness, refine_response, check_precision,
refine_query, max_iterations_reached
)
from workflow_conditions import should_continue_groundedness, should_continue_precision
def create_workflow() -> StateGraph:
    """Build the LangGraph state machine for the AI nutrition agent.

    Nodes and plain edges are declared as data and registered in loops;
    the two quality-control branches (groundedness and precision) are
    wired with conditional edges. The returned graph is uncompiled.
    """
    graph = StateGraph(AgentState)

    # Node name -> handler registry.
    node_registry = {
        "expand_query": expand_query,
        "retrieve_context": retrieve_context,
        "craft_response": craft_response,
        "score_groundedness": score_groundedness,
        "refine_response": refine_response,
        "check_precision": check_precision,
        "refine_query": refine_query,
        "max_iterations_reached": max_iterations_reached,
    }
    for node_name, handler in node_registry.items():
        graph.add_node(node_name, handler)

    # Straight-line pipeline, the two feedback loops, and the terminal edge.
    plain_edges = [
        (START, "expand_query"),
        ("expand_query", "retrieve_context"),
        ("retrieve_context", "craft_response"),
        ("craft_response", "score_groundedness"),
        ("refine_response", "craft_response"),  # groundedness feedback loop
        ("refine_query", "expand_query"),       # precision feedback loop
        ("max_iterations_reached", END),
    ]
    for source, target in plain_edges:
        graph.add_edge(source, target)

    # Branch on the groundedness verdict.
    graph.add_conditional_edges(
        "score_groundedness",
        should_continue_groundedness,
        {
            "check_precision": "check_precision",
            "refine_response": "refine_response",
            "max_iterations_reached": "max_iterations_reached",
        },
    )
    # Branch on the precision verdict.
    graph.add_conditional_edges(
        "check_precision",
        should_continue_precision,
        {
            "pass": END,
            "refine_query": "refine_query",
            "max_iterations_reached": "max_iterations_reached",
        },
    )
    return graph
# Compile the graph once at import time; the resulting app is reused by
# every `agentic_rag` invocation below.
WORKFLOW_APP = create_workflow().compile()
@tool
def agentic_rag(query: str) -> dict:
    """
    Runs the RAG workflow to produce a grounded, precise response to the query.

    Args:
        query (str): The current user query.

    Returns:
        Dict[str, Any]: The final workflow state, including the generated
        response plus the groundedness/precision scores and feedback fields.
    """
    # Seed the full AgentState: empty intermediates, zeroed scores and
    # loop counters, and each refinement loop capped at `loop_max_iter`
    # passes before the workflow bails out via max_iterations_reached.
    inputs = {
        "query": query,
        "expanded_query": "",
        "context": [],
        "response": "",
        "precision_score": 0.0,
        "groundedness_score": 0.0,
        "groundedness_loop_count": 0,
        "precision_loop_count": 0,
        "feedback": "",
        "query_feedback": "",
        "loop_max_iter": 2,
    }
    return WORKFLOW_APP.invoke(inputs)