mealgraph / workflow.py
moazeldegwy's picture
Simplify topology to 3 agents + 2 tools
1933348
"""LangGraph wiring for the Coach ↔ action loop.
Two-node graph: ``coach`` produces a single typed decision (``CoachDecision``)
and ``execute_action`` dispatches it to the right agent / tool / memory
write, then loops back. Terminates on ``compose_response`` or ``ask_user``
(or when ``max_turns`` is hit).
"""
from __future__ import annotations
import json
from datetime import datetime
from typing import Any, Dict
from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import END, StateGraph
from config import get_settings
from logging_setup import get_logger
from state import NutritionState
from utils import FileCheckpointSaver, set_nested, update_memory_partition
# Module-level logger for this file, obtained from the project's
# ``logging_setup.get_logger`` helper under the name "workflow".
_logger = get_logger("workflow")
def should_continue(state: NutritionState) -> str:
    """Routing predicate evaluated after ``execute_action``.

    Returns ``"end"`` when the action that just ran is terminal
    (``compose_response`` or ``ask_user``) or when ``num_turns`` has reached
    ``max_turns``; otherwise returns ``"execute_action"``, which the graph
    maps back to the ``coach`` node.

    The turn counters are only read when the action is not terminal, so a
    state without ``num_turns``/``max_turns`` can still end on a terminal
    action.
    """
    last_action = (state.get("current_action") or {}).get("action")
    finished = last_action in {"compose_response", "ask_user"}
    # ``or`` short-circuits, so the turn counters are not touched when finished.
    if finished or state["num_turns"] >= state["max_turns"]:
        return "end"
    return "execute_action"
def coach_node(state: NutritionState, coach_agent) -> NutritionState:
    """Graph node that asks the coach agent for its next decision.

    Delegates directly to ``coach_agent.handle_task(state)`` and returns the
    state that call produces, unchanged.
    """
    next_state = coach_agent.handle_task(state)
    return next_state
# Actions after which ``should_continue`` ends the run. If one of them fails,
# no other node gets a chance to reply, so the failure path must add one.
_TERMINAL_ACTIONS = frozenset({"compose_response", "ask_user"})

# Assistant message appended when the coach's decision could not be parsed.
_PARSE_ERROR_MESSAGE = "I encountered an error processing the request. Let me try a different approach."

# Assistant message appended when a terminal action fails. The graph stops
# right after such an action, so without this the user would get no reply.
_TERMINAL_FAILURE_MESSAGE = "Sorry, I ran into a problem while preparing my reply. Could you please try again?"


def _coerce_memory_data(raw_data: Any) -> Dict[str, Any]:
    """Return a ``write_memory`` payload as a dict, decoding a JSON string if needed.

    ``data`` is normally a dict. A JSON-encoded string is accepted too, so
    alternative SDK shapes work without special-casing the agents.

    Raises:
        ValueError: if ``raw_data`` is a string that is not valid JSON, or if
            the resulting value is not a dict.
    """
    if isinstance(raw_data, str):
        try:
            data = json.loads(raw_data)
        except json.JSONDecodeError as decode_err:
            raise ValueError(
                f"write_memory.data must be an object or JSON string; got: {raw_data!r}"
            ) from decode_err
    else:
        data = raw_data
    if not isinstance(data, dict):
        raise ValueError(f"write_memory.data must be an object, got {type(data).__name__}")
    return data


def execute_action_node(state: NutritionState, agents: Dict[str, Any], tools: Dict[str, Any]) -> NutritionState:
    """Carry out the coach's ``current_action`` and report the outcome.

    Supported actions are ``call_agent``, ``write_memory``,
    ``compose_response`` and ``ask_user``. Any other name is recorded and
    reported as ``"Unknown action: ..."``.

    The node mutates ``state["conversation_history"]``,
    ``state["previous_actions"]`` and ``state["memory"]`` in place. It
    returns a shallow copy of ``state`` with ``agent_result`` set to a short
    outcome string.

    Args:
        state: Current graph state. The action is read from ``current_action``.
        agents: Mapping from agent name to an object exposing
            ``handle_task(task, memory)``. Used by ``call_agent``.
        tools: Not used by this function; kept for interface compatibility
            with ``setup_workflow``.

    Returns:
        The input state unchanged when there is no action name. Otherwise a
        new dict with ``agent_result`` populated.

    Errors raised while executing an action are logged and turned into an
    ``"Error executing ..."`` ``agent_result`` rather than propagated. When the
    failed action is terminal, a fallback assistant message is also appended
    so the conversation still ends with a reply.
    """
    action = state.get("current_action") or {}
    action_name = action.get("action")
    params = action.get("params") or {}
    if not action_name:
        return state

    settings = get_settings()
    if settings.debug_mode:
        _logger.debug("Executing Action: %s", action_name)
    elif action_name == "ask_user":
        _logger.info("❓ Asking user: %s", params.get("prompt"))
    elif action_name == "write_memory":
        _logger.info("Writing to memory partition: %s", params.get("partition"))

    if action.get("_parse_error"):
        state["conversation_history"].append({"role": "assistant", "content": _PARSE_ERROR_MESSAGE})
        return {**state, "agent_result": _PARSE_ERROR_MESSAGE}

    if "previous_actions" not in state:
        state["previous_actions"] = []

    try:
        if action_name == "call_agent":
            agent_name = params["agent_name"]
            task = params["task"]
            agent_result = agents[agent_name].handle_task(task, state["memory"])
            # Only the truthiness of the agent's return value is used here.
            # The agent presumably records its output in ``state["memory"]``
            # itself; confirm against the agent implementations.
            success_message = (
                f"{agent_name} task completed and stored in the memory successfully"
                if agent_result
                else f"{agent_name} task failed"
            )
            state["previous_actions"].append(f"Called agent {agent_name} with task: {task}")
            return {**state, "agent_result": success_message}

        if action_name == "write_memory":
            partition = params["partition"]
            data = _coerce_memory_data(params["data"])
            updated_data = {**data, "last_updated": datetime.now().isoformat()}
            # Top-level partitions (user_profile, medical_history, ...) are
            # merged, so a partial update does not erase pre-existing keys.
            # Dotted paths (``flags_and_assessments.last_validation``) are
            # treated as a leaf assignment.
            if "." in partition:
                set_nested(state["memory"], partition, updated_data)
            else:
                update_memory_partition(state["memory"], partition, updated_data)
            state["previous_actions"].append(f"Wrote to memory partition: {partition}")
            return {**state, "agent_result": "Memory updated successfully"}

        if action_name == "compose_response":
            response_text = params.get("text") or params.get("response")
            if not response_text:
                raise ValueError("Missing 'text' or 'response' in params for compose_response")
            state["conversation_history"].append({"role": "assistant", "content": response_text})
            state["previous_actions"].append("Composed response to user")
            return {**state, "agent_result": response_text}

        if action_name == "ask_user":
            prompt_text = params["prompt"]
            state["conversation_history"].append({"role": "assistant", "content": prompt_text})
            state["previous_actions"].append(f"Asked user: {prompt_text}")
            return {**state, "agent_result": "User prompted for input"}

        state["previous_actions"].append(f"Executed {action_name} with params: {params}")
        return {**state, "agent_result": f"Unknown action: {action_name}"}
    except Exception as e:  # noqa: BLE001 - surfaced to the coach via agent_result
        _logger.exception("Error executing %s", action_name)
        state["previous_actions"].append(f"Attempted {action_name} with params: {params}")
        if action_name in _TERMINAL_ACTIONS:
            # should_continue() ends the run after a terminal action even when
            # it failed, so the coach cannot retry. Leave the user a reply.
            state.setdefault("conversation_history", []).append(
                {"role": "assistant", "content": _TERMINAL_FAILURE_MESSAGE}
            )
        return {**state, "agent_result": f"Error executing {action_name}: {e}"}
def setup_workflow(coach_agent, agents: Dict[str, Any], tools: Dict[str, Any], persistence_dir: str | None = None):
    """Build and compile the coach <-> execute_action graph.

    Wiring: the entry point is ``coach``, which always flows into
    ``execute_action``. After that, ``should_continue`` either loops back to
    ``coach`` or stops at ``END``.

    Args:
        coach_agent: Passed to ``coach_node``, which calls its
            ``handle_task(state)``.
        agents: Forwarded unchanged to ``execute_action_node``.
        tools: Forwarded unchanged to ``execute_action_node``.
        persistence_dir: When truthy, checkpoints use
            ``FileCheckpointSaver(persistence_dir)``. Otherwise an in-memory
            ``MemorySaver`` is used.

    Returns:
        The result of ``StateGraph.compile`` with the chosen checkpointer.
    """
    def run_coach(state):
        return coach_node(state, coach_agent)

    def run_action(state):
        return execute_action_node(state, agents, tools)

    graph = StateGraph(NutritionState)
    graph.add_node("coach", run_coach)
    graph.add_node("execute_action", run_action)
    graph.set_entry_point("coach")
    graph.add_edge("coach", "execute_action")
    # should_continue's "execute_action" label means "keep looping", which
    # routes back to the coach for its next decision.
    graph.add_conditional_edges(
        "execute_action",
        should_continue,
        {"execute_action": "coach", "end": END},
    )

    if not persistence_dir:
        checkpointer = MemorySaver()
        _logger.info("MAS workflow compiled with in-memory persistence.")
    else:
        checkpointer = FileCheckpointSaver(persistence_dir)
        _logger.info("MAS workflow compiled with file-based persistence at %s.", persistence_dir)
    return graph.compile(checkpointer=checkpointer)