"""LangGraph wiring for the Coach ↔ action loop.

Two-node graph: ``coach`` produces a single typed decision
(``CoachDecision``) and ``execute_action`` dispatches it to the right
agent / tool / memory write, then loops back. Terminates once
``compose_response`` or ``ask_user`` has executed successfully (or when
``max_turns`` is hit).
"""

from __future__ import annotations

import json
from datetime import datetime
from typing import Any, Dict

from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import END, StateGraph

from config import get_settings
from logging_setup import get_logger
from state import NutritionState
from utils import FileCheckpointSaver, set_nested, update_memory_partition

_logger = get_logger("workflow")

# Actions that hand control back to the user and therefore end the graph run.
_TERMINAL_ACTIONS = frozenset({"compose_response", "ask_user"})


def should_continue(state: NutritionState) -> str:
    """Edge predicate evaluated after ``execute_action``.

    Returns ``"end"`` when the current action is terminal
    (``compose_response`` / ``ask_user``) or when ``num_turns`` has reached
    ``max_turns``. Otherwise returns ``"execute_action"``, a branch label
    that ``setup_workflow`` maps back to the ``coach`` node.
    """
    current = state.get("current_action") or {}
    if current.get("action") in _TERMINAL_ACTIONS:
        return "end"
    if state["num_turns"] >= state["max_turns"]:
        return "end"
    return "execute_action"


def coach_node(state: NutritionState, coach_agent) -> NutritionState:
    """Delegate to ``coach_agent.handle_task`` and return the state it produces."""
    return coach_agent.handle_task(state)


def _log_action(action_name: str, params: Dict[str, Any]) -> None:
    """Log the action about to run.

    In debug mode every action is logged. Otherwise only user prompts and
    memory writes are logged.
    """
    if get_settings().debug_mode:
        _logger.debug("Executing Action: %s", action_name)
    elif action_name == "ask_user":
        _logger.info("❓ Asking user: %s", params.get("prompt"))
    elif action_name == "write_memory":
        _logger.info("Writing to memory partition: %s", params.get("partition"))


def _run_call_agent(state: NutritionState, params: Dict[str, Any], agents: Dict[str, Any]) -> str:
    """Hand ``params["task"]`` to ``agents[params["agent_name"]]`` and summarise the outcome.

    The agent receives ``state["memory"]`` alongside the task. The returned
    summary depends only on whether the agent's result is truthy.
    """
    agent_name = params["agent_name"]
    task = params["task"]
    agent_result = agents[agent_name].handle_task(task, state["memory"])
    success_message = (
        f"{agent_name} task completed and stored in the memory successfully"
        if agent_result
        else f"{agent_name} task failed"
    )
    state["previous_actions"].append(f"Called agent {agent_name} with task: {task}")
    return success_message


def _coerce_memory_data(raw_data: Any) -> Dict[str, Any]:
    """Return ``write_memory`` payload data as a dict.

    ``data`` is normally a dict. A JSON-encoded string is also accepted, so
    alternative SDK shapes work without special-casing the agents.

    Raises:
        ValueError: if the string is not valid JSON, or the value (after
            decoding) is not a dict.
    """
    if isinstance(raw_data, str):
        try:
            data = json.loads(raw_data)
        except json.JSONDecodeError as decode_err:
            raise ValueError(
                f"write_memory.data must be an object or JSON string; got: {raw_data!r}"
            ) from decode_err
    else:
        data = raw_data
    if not isinstance(data, dict):
        raise ValueError(f"write_memory.data must be an object, got {type(data).__name__}")
    return data


def _run_write_memory(state: NutritionState, params: Dict[str, Any]) -> str:
    """Write ``params["data"]`` into ``state["memory"]``, stamped with ``last_updated``."""
    partition = params["partition"]
    data = _coerce_memory_data(params["data"])
    updated_data = {**data, "last_updated": datetime.now().isoformat()}
    # Top-level partitions (user_profile, medical_history, …) are merged, so a
    # partial update does not erase pre-existing keys. Dotted paths
    # (``flags_and_assessments.last_validation``) are treated as a leaf
    # assignment.
    if "." in partition:
        set_nested(state["memory"], partition, updated_data)
    else:
        update_memory_partition(state["memory"], partition, updated_data)
    state["previous_actions"].append(f"Wrote to memory partition: {partition}")
    return "Memory updated successfully"


def _run_compose_response(state: NutritionState, params: Dict[str, Any]) -> str:
    """Append the final assistant reply (``text``, falling back to ``response``) to the history."""
    response_text = params.get("text") or params.get("response")
    if not response_text:
        raise ValueError("Missing 'text' or 'response' in params for compose_response")
    state["conversation_history"].append({"role": "assistant", "content": response_text})
    state["previous_actions"].append("Composed response to user")
    return response_text


def _run_ask_user(state: NutritionState, params: Dict[str, Any]) -> str:
    """Append ``params["prompt"]`` to the history as an assistant message."""
    prompt_text = params["prompt"]
    state["conversation_history"].append({"role": "assistant", "content": prompt_text})
    state["previous_actions"].append(f"Asked user: {prompt_text}")
    return "User prompted for input"


def execute_action_node(state: NutritionState, agents: Dict[str, Any], tools: Dict[str, Any]) -> NutritionState:
    """Execute the coach's ``current_action`` and record its outcome.

    Args:
        state: Graph state. ``current_action`` is read. ``conversation_history``,
            ``previous_actions`` and ``memory`` may be mutated in place.
        agents: Mapping from agent name to an object with
            ``handle_task(task, memory)``. Used by ``call_agent``.
        tools: Not used by this function; kept for signature compatibility.

    Returns:
        ``state`` unchanged when there is no action name. Otherwise a copy of
        the state with ``agent_result`` set to a short outcome string.

    Errors raised while executing an action are logged and reported through
    ``agent_result`` instead of propagating. If the failing action is terminal
    (``compose_response`` / ``ask_user``), ``current_action`` is cleared in the
    returned state. ``should_continue`` then loops back to the coach instead of
    ending the run without any message for the user.
    """
    action = state.get("current_action") or {}
    action_name = action.get("action")
    params = action.get("params", {}) or {}
    if not action_name:
        return state

    _log_action(action_name, params)

    if action.get("_parse_error"):
        error_message = "I encountered an error processing the request. Let me try a different approach."
        state["conversation_history"].append({"role": "assistant", "content": error_message})
        return {**state, "agent_result": error_message}

    state.setdefault("previous_actions", [])

    try:
        if action_name == "call_agent":
            result = _run_call_agent(state, params, agents)
        elif action_name == "write_memory":
            result = _run_write_memory(state, params)
        elif action_name == "compose_response":
            result = _run_compose_response(state, params)
        elif action_name == "ask_user":
            result = _run_ask_user(state, params)
        else:
            state["previous_actions"].append(f"Executed {action_name} with params: {params}")
            result = f"Unknown action: {action_name}"
    except Exception as e:  # noqa: BLE001 - any failure is reported back to the coach
        _logger.exception("Error executing %s", action_name)
        state["previous_actions"].append(f"Attempted {action_name} with params: {params}")
        failed_state = {**state, "agent_result": f"Error executing {action_name}: {str(e)}"}
        if action_name in _TERMINAL_ACTIONS:
            # A failed terminal action produced no user-facing message. Clear it
            # so should_continue sends control back to the coach to recover.
            failed_state["current_action"] = None
        return failed_state
    return {**state, "agent_result": result}


def setup_workflow(coach_agent, agents: Dict[str, Any], tools: Dict[str, Any], persistence_dir: str | None = None):
    """Build and compile the ``coach`` ↔ ``execute_action`` graph.

    Args:
        coach_agent: Object whose ``handle_task(state)`` is called by the
            ``coach`` node.
        agents: Forwarded to ``execute_action_node``.
        tools: Forwarded to ``execute_action_node``.
        persistence_dir: When truthy, checkpoints use ``FileCheckpointSaver``
            at this location. Otherwise an in-memory ``MemorySaver`` is used.

    Returns:
        The compiled graph.
    """
    workflow = StateGraph(NutritionState)
    workflow.add_node("coach", lambda state: coach_node(state, coach_agent))
    workflow.add_node("execute_action", lambda state: execute_action_node(state, agents, tools))
    workflow.set_entry_point("coach")
    workflow.add_edge("coach", "execute_action")
    # should_continue's "execute_action" label loops back to the coach for its
    # next decision; "end" finishes the run.
    workflow.add_conditional_edges(
        "execute_action",
        should_continue,
        {"execute_action": "coach", "end": END},
    )

    if persistence_dir:
        checkpointer = FileCheckpointSaver(persistence_dir)
        _logger.info("MAS workflow compiled with file-based persistence at %s.", persistence_dir)
    else:
        checkpointer = MemorySaver()
        _logger.info("MAS workflow compiled with in-memory persistence.")
    return workflow.compile(checkpointer=checkpointer)