"""LangGraph wiring for the Coach ↔ action loop.
Two-node graph: ``coach`` produces a single typed decision (``CoachDecision``)
and ``execute_action`` dispatches it to the right agent / tool / memory
write, then loops back. Terminates on ``compose_response`` or ``ask_user``
(or when ``max_turns`` is hit).
"""
from __future__ import annotations
import json
from datetime import datetime
from typing import Any, Dict
from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import END, StateGraph
from config import get_settings
from logging_setup import get_logger
from state import NutritionState
from utils import FileCheckpointSaver, set_nested, update_memory_partition
_logger = get_logger("workflow")
def should_continue(state: NutritionState) -> str:
    """Edge predicate for the graph: decide whether to loop or stop.

    Returns ``"end"`` when the last action was terminal (``compose_response``
    or ``ask_user``) or the turn budget is exhausted; otherwise routes back
    to ``"execute_action"``.
    """
    latest = state.get("current_action") or {}
    is_terminal = latest.get("action") in {"compose_response", "ask_user"}
    budget_spent = state["num_turns"] >= state["max_turns"]
    return "end" if (is_terminal or budget_spent) else "execute_action"
def coach_node(state: NutritionState, coach_agent) -> NutritionState:
    """Run the coach agent over the current state and return its decision state."""
    decision_state = coach_agent.handle_task(state)
    return decision_state
def _log_action(action_name: str, params: Dict[str, Any]) -> None:
    """Emit the per-action log line: verbose in debug mode, terse otherwise."""
    settings = get_settings()
    if settings.debug_mode:
        _logger.debug("Executing Action: %s", action_name)
    else:
        if action_name == "ask_user":
            _logger.info("❓ Asking user: %s", params.get("prompt"))
        elif action_name == "write_memory":
            _logger.info("Writing to memory partition: %s", params.get("partition"))


def _coerce_memory_data(raw_data: Any) -> Dict[str, Any]:
    """Normalize ``write_memory.data`` to a dict.

    ``data`` is normally a dict; a JSON-encoded string is accepted too so
    alternative SDK shapes work without special-casing the agents.

    Raises:
        ValueError: if the payload is not a dict or a JSON object string.
    """
    if isinstance(raw_data, str):
        try:
            data = json.loads(raw_data)
        except json.JSONDecodeError as decode_err:
            raise ValueError(
                f"write_memory.data must be an object or JSON string; got: {raw_data!r}"
            ) from decode_err
    else:
        data = raw_data
    if not isinstance(data, dict):
        raise ValueError(f"write_memory.data must be an object, got {type(data).__name__}")
    return data


def _execute_call_agent(state: NutritionState, params: Dict[str, Any], agents: Dict[str, Any]) -> NutritionState:
    """Delegate a task to a named sub-agent; the agent writes into memory itself."""
    agent_name = params["agent_name"]
    task = params["task"]
    agent_result = agents[agent_name].handle_task(task, state["memory"])
    success_message = (
        f"{agent_name} task completed and stored in the memory successfully"
        if agent_result
        else f"{agent_name} task failed"
    )
    state["previous_actions"].append(f"Called agent {agent_name} with task: {task}")
    return {**state, "agent_result": success_message}


def _execute_write_memory(state: NutritionState, params: Dict[str, Any]) -> NutritionState:
    """Merge or assign a payload into the memory tree, stamping ``last_updated``."""
    partition = params["partition"]
    data = _coerce_memory_data(params["data"])
    updated_data = {**data, "last_updated": datetime.now().isoformat()}
    # Top-level partitions (user_profile, medical_history, …) are merged so a
    # partial update does not erase pre-existing keys. Dotted paths
    # (``flags_and_assessments.last_validation``) are treated as a leaf
    # assignment.
    if "." in partition:
        set_nested(state["memory"], partition, updated_data)
    else:
        update_memory_partition(state["memory"], partition, updated_data)
    state["previous_actions"].append(f"Wrote to memory partition: {partition}")
    return {**state, "agent_result": "Memory updated successfully"}


def _execute_compose_response(state: NutritionState, params: Dict[str, Any]) -> NutritionState:
    """Append the final assistant message to the conversation history.

    Raises:
        ValueError: if neither ``text`` nor ``response`` carries a non-empty value.
    """
    response_text = params.get("text") or params.get("response")
    if not response_text:
        raise ValueError("Missing 'text' or 'response' in params for compose_response")
    state["conversation_history"].append({"role": "assistant", "content": response_text})
    state["previous_actions"].append("Composed response to user")
    return {**state, "agent_result": response_text}


def _execute_ask_user(state: NutritionState, params: Dict[str, Any]) -> NutritionState:
    """Record a clarifying question to the user in the conversation history."""
    prompt_text = params["prompt"]
    state["conversation_history"].append({"role": "assistant", "content": prompt_text})
    state["previous_actions"].append(f"Asked user: {prompt_text}")
    return {**state, "agent_result": "User prompted for input"}


def execute_action_node(state: NutritionState, agents: Dict[str, Any], tools: Dict[str, Any]) -> NutritionState:
    """Dispatch the coach's ``current_action`` to the matching handler.

    Args:
        state: The shared graph state; ``current_action`` holds the typed
            decision produced by the coach node.
        agents: Name → agent mapping used by ``call_agent``.
        tools: Name → tool mapping (accepted for interface parity; no action
            in this dispatcher currently consumes it).

    Returns:
        The (possibly updated) state with ``agent_result`` describing the
        outcome. Failures are captured and reported via ``agent_result``
        rather than propagated, so the loop keeps running.
    """
    action = state.get("current_action") or {}
    action_name = action.get("action")
    params = action.get("params", {}) or {}
    if not action_name:
        # The coach produced no action this turn — nothing to execute.
        return state

    _log_action(action_name, params)

    if action.get("_parse_error"):
        # The coach's output could not be parsed into a typed action; surface
        # a graceful apology instead of crashing the loop.
        error_message = "I encountered an error processing the request. Let me try a different approach."
        state["conversation_history"].append({"role": "assistant", "content": error_message})
        return {**state, "agent_result": error_message}

    state.setdefault("previous_actions", [])

    try:
        if action_name == "call_agent":
            return _execute_call_agent(state, params, agents)
        if action_name == "write_memory":
            return _execute_write_memory(state, params)
        if action_name == "compose_response":
            return _execute_compose_response(state, params)
        if action_name == "ask_user":
            return _execute_ask_user(state, params)
        # Unknown action: record it and report back so the coach can recover.
        state["previous_actions"].append(f"Executed {action_name} with params: {params}")
        return {**state, "agent_result": f"Unknown action: {action_name}"}
    except Exception as e:  # noqa: BLE001
        _logger.exception("Error executing %s", action_name)
        state["previous_actions"].append(f"Attempted {action_name} with params: {params}")
        return {**state, "agent_result": f"Error executing {action_name}: {str(e)}"}
def setup_workflow(coach_agent, agents: Dict[str, Any], tools: Dict[str, Any], persistence_dir: str | None = None):
    """Build and compile the two-node coach ↔ execute_action LangGraph.

    Args:
        coach_agent: Agent whose ``handle_task`` produces the next decision.
        agents: Name → agent mapping handed to the execute node.
        tools: Name → tool mapping handed to the execute node.
        persistence_dir: When given, checkpoints go to this directory via
            ``FileCheckpointSaver``; otherwise an in-memory saver is used.

    Returns:
        The compiled graph, ready to be invoked.
    """
    graph = StateGraph(NutritionState)
    graph.add_node("coach", lambda s: coach_node(s, coach_agent))
    graph.add_node("execute_action", lambda s: execute_action_node(s, agents, tools))
    graph.set_entry_point("coach")
    graph.add_edge("coach", "execute_action")
    graph.add_conditional_edges(
        "execute_action",
        should_continue,
        {"execute_action": "coach", "end": END},
    )

    if persistence_dir:
        saver = FileCheckpointSaver(persistence_dir)
        _logger.info("MAS workflow compiled with file-based persistence at %s.", persistence_dir)
    else:
        saver = MemorySaver()
        _logger.info("MAS workflow compiled with in-memory persistence.")
    return graph.compile(checkpointer=saver)