Spaces:
Sleeping
Sleeping
File size: 2,903 Bytes
60ce079 e68a66e 60ce079 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 |
from langgraph.graph import StateGraph, END, START
from langchain_core.tools import tool
from agent_state import AgentState
from workflow_nodes import (
expand_query, retrieve_context, craft_response,
score_groundedness, refine_response, check_precision,
refine_query, max_iterations_reached
)
from workflow_conditions import should_continue_groundedness, should_continue_precision
def create_workflow() -> StateGraph:
    """Assemble the (uncompiled) LangGraph state graph for the AI nutrition agent.

    Flow:
        START -> expand_query -> retrieve_context -> craft_response
            -> score_groundedness
        score_groundedness routes via ``should_continue_groundedness`` to
            check_precision, refine_response, or max_iterations_reached.
        refine_response loops back to craft_response.
        check_precision routes via ``should_continue_precision`` to END
            (on "pass"), refine_query, or max_iterations_reached.
        refine_query loops back to expand_query.
        max_iterations_reached -> END

    Returns:
        The ``StateGraph`` definition; the caller is expected to ``.compile()`` it.
    """
    graph = StateGraph(AgentState)

    # Each node is registered under the same name as the function that runs it.
    nodes = (
        ("expand_query", expand_query),
        ("retrieve_context", retrieve_context),
        ("craft_response", craft_response),
        ("score_groundedness", score_groundedness),
        ("refine_response", refine_response),
        ("check_precision", check_precision),
        ("refine_query", refine_query),
        ("max_iterations_reached", max_iterations_reached),
    )
    for node_name, node_fn in nodes:
        graph.add_node(node_name, node_fn)

    # Straight-line main path, from the entry point up to the groundedness check.
    main_path = (
        START,
        "expand_query",
        "retrieve_context",
        "craft_response",
        "score_groundedness",
    )
    for src, dst in zip(main_path, main_path[1:]):
        graph.add_edge(src, dst)

    # The groundedness condition returns the name of the next node.
    groundedness_routes = {
        "check_precision": "check_precision",
        "refine_response": "refine_response",
        "max_iterations_reached": "max_iterations_reached",
    }
    graph.add_conditional_edges(
        "score_groundedness", should_continue_groundedness, groundedness_routes
    )
    # A refined response goes back through craft_response and is scored again.
    graph.add_edge("refine_response", "craft_response")

    # "pass" ends the run; otherwise refine the query or hand off to
    # max_iterations_reached (which itself leads to END).
    precision_routes = {
        "pass": END,
        "refine_query": "refine_query",
        "max_iterations_reached": "max_iterations_reached",
    }
    graph.add_conditional_edges(
        "check_precision", should_continue_precision, precision_routes
    )
    # A refined query restarts the pipeline from query expansion.
    graph.add_edge("refine_query", "expand_query")
    graph.add_edge("max_iterations_reached", END)

    return graph
# Compile the graph once, at import time; this single compiled runnable is
# reused by every call to the ``agentic_rag`` tool (via WORKFLOW_APP.invoke).
WORKFLOW_APP = create_workflow().compile()
# Value passed to the workflow as ``loop_max_iter``. Presumably read by the
# routing conditions to cap the groundedness/precision refinement loops
# (TODO confirm against workflow_conditions).
_LOOP_MAX_ITER = 2


@tool
def agentic_rag(query: str) -> dict:
    """
    Answers a user query with the agentic RAG workflow: expands the query, retrieves context, drafts a response, then checks it for groundedness and precision, refining the response or the query as needed up to an iteration limit.
    Args:
        query (str): The current user query.
    Returns:
        Dict[str, Any]: The final workflow state, including the generated answer under "response".
    """
    # NOTE: with @tool, the docstring above is the tool description shown to
    # the model, so it must describe only what this tool really does (no
    # conversation history is passed in or kept between calls).
    # Every call starts from a fresh state: empty scores, zeroed loop counters.
    inputs = {
        "query": query,
        "expanded_query": "",
        "context": [],
        "response": "",
        "precision_score": 0.0,
        "groundedness_score": 0.0,
        "groundedness_loop_count": 0,
        "precision_loop_count": 0,
        "feedback": "",
        "query_feedback": "",
        "loop_max_iter": _LOOP_MAX_ITER,
    }
    return WORKFLOW_APP.invoke(inputs)
|