Spaces:
Sleeping
Sleeping
| from langgraph.graph import StateGraph, END, START | |
| from langchain_core.tools import tool | |
| from agent_state import AgentState | |
| from workflow_nodes import ( | |
| expand_query, retrieve_context, craft_response, | |
| score_groundedness, refine_response, check_precision, | |
| refine_query, max_iterations_reached | |
| ) | |
| from workflow_conditions import should_continue_groundedness, should_continue_precision | |
def create_workflow() -> StateGraph:
    """Assemble the (uncompiled) state graph for the AI nutrition agent.

    Topology wired below:
      START -> expand_query -> retrieve_context -> craft_response
            -> score_groundedness
      score_groundedness is routed by ``should_continue_groundedness`` to
        check_precision, refine_response, or max_iterations_reached;
        refine_response loops back to craft_response.
      check_precision is routed by ``should_continue_precision`` to END on
        "pass", to refine_query, or to max_iterations_reached;
        refine_query loops back to expand_query.
      max_iterations_reached -> END

    Returns:
        StateGraph: the graph builder over ``AgentState``; call
        ``.compile()`` on it to obtain an invocable runnable.
    """
    graph = StateGraph(AgentState)

    # Register every processing step; node names mirror the function names.
    node_table = (
        ("expand_query", expand_query),
        ("retrieve_context", retrieve_context),
        ("craft_response", craft_response),
        ("score_groundedness", score_groundedness),
        ("refine_response", refine_response),
        ("check_precision", check_precision),
        ("refine_query", refine_query),
        ("max_iterations_reached", max_iterations_reached),
    )
    for node_name, node_fn in node_table:
        graph.add_node(node_name, node_fn)

    # Straight-line pipeline: expand -> retrieve -> generate -> grade.
    for source, target in (
        (START, "expand_query"),
        ("expand_query", "retrieve_context"),
        ("retrieve_context", "craft_response"),
        ("craft_response", "score_groundedness"),
    ):
        graph.add_edge(source, target)

    # Groundedness gate: each router label is also the name of its target node.
    groundedness_routes = {
        label: label
        for label in ("check_precision", "refine_response", "max_iterations_reached")
    }
    graph.add_conditional_edges(
        "score_groundedness", should_continue_groundedness, groundedness_routes
    )
    graph.add_edge("refine_response", "craft_response")

    # Precision gate: "pass" terminates the run; other labels name nodes.
    precision_routes = {
        "pass": END,
        "refine_query": "refine_query",
        "max_iterations_reached": "max_iterations_reached",
    }
    graph.add_conditional_edges(
        "check_precision", should_continue_precision, precision_routes
    )
    graph.add_edge("refine_query", "expand_query")
    graph.add_edge("max_iterations_reached", END)

    return graph
# Create workflow instance: build and compile the graph once, at import time,
# so every agentic_rag() call reuses the same compiled runnable.
_uncompiled_workflow = create_workflow()
WORKFLOW_APP = _uncompiled_workflow.compile()
del _uncompiled_workflow  # keep the module namespace unchanged
def agentic_rag(query: str, *, loop_max_iter: int = 2):
    """Run the compiled RAG workflow for a single user query.

    The graph starts from a fresh initial state: only ``query`` is filled in,
    and every other field is reset to an empty/zero value. No conversation
    history is carried between calls.

    Args:
        query (str): The current user query.
        loop_max_iter (int): Iteration cap stored in the state as
            ``"loop_max_iter"``. It is presumably compared against the
            groundedness/precision loop counters by the routing conditions
            (TODO confirm in workflow_conditions). Defaults to 2, the value
            previously hard-coded here.

    Returns:
        The final state produced by ``WORKFLOW_APP.invoke``. This is
        presumably a dict containing the generated ``"response"`` and the
        scores and loop counters (confirm against AgentState).

    Raises:
        ValueError: If ``loop_max_iter`` is negative.
    """
    if loop_max_iter < 0:
        raise ValueError(f"loop_max_iter must be >= 0, got {loop_max_iter}")

    # Every loop counter and score starts at zero so refinement loops begin fresh.
    inputs = {
        "query": query,
        "expanded_query": "",
        "context": [],
        "response": "",
        "precision_score": 0.0,
        "groundedness_score": 0.0,
        "groundedness_loop_count": 0,
        "precision_loop_count": 0,
        "feedback": "",
        "query_feedback": "",
        "loop_max_iter": loop_max_iter,
    }
    return WORKFLOW_APP.invoke(inputs)