File size: 6,523 Bytes
9918f43
88245f7
9918f43
 
 
 
88245f7
 
 
 
9918f43
88245f7
 
 
bcd961e
88245f7
 
 
 
bcd961e
9918f43
88245f7
 
 
bcd961e
 
88245f7
 
 
bcd961e
 
 
 
 
88245f7
bcd961e
 
 
88245f7
 
 
 
 
 
 
bcd961e
88245f7
 
 
 
 
 
 
 
 
 
bcd961e
 
 
 
 
88245f7
 
bcd961e
 
88245f7
 
 
bcd961e
88245f7
 
 
 
 
 
bcd961e
 
88245f7
 
9918f43
 
 
 
 
 
 
 
 
 
 
 
 
 
bcd961e
9918f43
 
 
 
 
 
 
 
88245f7
bcd961e
 
88245f7
 
bcd961e
 
 
88245f7
bcd961e
 
88245f7
 
bcd961e
88245f7
bcd961e
 
88245f7
 
bcd961e
88245f7
 
 
 
bcd961e
88245f7
 
bcd961e
 
 
 
 
88245f7
 
 
 
 
 
bcd961e
 
88245f7
bcd961e
 
88245f7
bcd961e
88245f7
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
"""LangGraph wiring for the Coach ↔ action loop.

Two-node graph: ``coach`` produces a single typed decision (``CoachDecision``)
and ``execute_action`` dispatches it to the right agent / tool / memory
write, then loops back. Terminates on ``compose_response`` or ``ask_user``
(or when ``max_turns`` is hit).
"""

from __future__ import annotations

import json
from datetime import datetime
from typing import Any, Dict

from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import END, StateGraph

from config import get_settings
from logging_setup import get_logger
from state import NutritionState
from utils import FileCheckpointSaver, set_nested, update_memory_partition

_logger = get_logger("workflow")


def should_continue(state: NutritionState) -> str:
    """Route the graph after ``execute_action`` has run.

    Returns ``"end"`` when the decision just executed is terminal
    (``compose_response`` or ``ask_user``) or when ``num_turns`` has reached
    ``max_turns``; otherwise returns ``"execute_action"``, which the
    conditional edge maps back to the ``coach`` node.
    """
    last_action_name = (state.get("current_action") or {}).get("action")
    # ``or`` short-circuits, so the turn counters are only read for
    # non-terminal actions.
    is_finished = (
        last_action_name in {"compose_response", "ask_user"}
        or state["num_turns"] >= state["max_turns"]
    )
    return "end" if is_finished else "execute_action"


def coach_node(state: NutritionState, coach_agent) -> NutritionState:
    """Run a single Coach turn.

    Hands the full graph state to ``coach_agent.handle_task`` and returns
    whatever it produces, unchanged, as this node's output.
    """
    next_state = coach_agent.handle_task(state)
    return next_state


def _coerce_memory_data(raw_data: Any) -> Dict[str, Any]:
    """Return ``write_memory`` payload ``raw_data`` as a dict.

    ``data`` is normally a dict; a JSON-encoded string is accepted too so
    alternative SDK shapes work without special-casing the agents.

    Raises:
        ValueError: if the string is not valid JSON, or the (decoded) value
            is not a JSON object / dict.
    """
    if isinstance(raw_data, str):
        try:
            data = json.loads(raw_data)
        except json.JSONDecodeError as decode_err:
            raise ValueError(
                f"write_memory.data must be an object or JSON string; got: {raw_data!r}"
            ) from decode_err
    else:
        data = raw_data
    if not isinstance(data, dict):
        raise ValueError(f"write_memory.data must be an object, got {type(data).__name__}")
    return data


def execute_action_node(state: NutritionState, agents: Dict[str, Any], tools: Dict[str, Any]) -> NutritionState:
    """Dispatch the Coach's current decision and record its outcome.

    Reads ``state["current_action"]`` (a dict with ``action`` and ``params``),
    performs the matching side effect, appends a short description to
    ``state["previous_actions"]`` and returns the state with ``agent_result``
    set to a human-readable outcome string.

    Handled actions:

    * ``call_agent``: ``agents[params["agent_name"]].handle_task(params["task"], state["memory"])``
    * ``write_memory``: merge (top-level partition) or assign (dotted path)
      ``params["data"]`` into ``state["memory"]``, stamped with ``last_updated``
    * ``compose_response``: append ``params["text"]`` (or ``params["response"]``)
      to ``conversation_history`` as an assistant turn
    * ``ask_user``: append ``params["prompt"]`` to ``conversation_history``

    Args:
        state: Current graph state. ``conversation_history``, ``memory`` and
            ``previous_actions`` are mutated in place.
        agents: Sub-agents keyed by name, used by ``call_agent``.
        tools: Tool registry; accepted for interface compatibility but not
            used by any action handled here.

    Returns:
        The state with ``agent_result`` updated. Exceptions raised while
        executing an action are logged and reported via ``agent_result``
        instead of propagating.
    """
    action = state.get("current_action") or {}
    action_name = action.get("action")
    params = action.get("params", {}) or {}
    parse_error = bool(action.get("_parse_error"))

    # A decision the Coach failed to parse may carry no usable action name.
    # Only bail out when there is neither an action nor a parse error;
    # otherwise the parse-error fallback below would be skipped silently.
    if not action_name and not parse_error:
        return state

    settings = get_settings()
    if settings.debug_mode:
        _logger.debug("Executing Action: %s", action_name)
    elif action_name == "ask_user":
        _logger.info("❓ Asking user: %s", params.get("prompt"))
    elif action_name == "write_memory":
        _logger.info("Writing to memory partition: %s", params.get("partition"))

    if parse_error:
        error_message = "I encountered an error processing the request. Let me try a different approach."
        state["conversation_history"].append({"role": "assistant", "content": error_message})
        return {**state, "agent_result": error_message}

    # Treat an explicit ``None`` like a missing list so every append below,
    # including the one in the error handler, is safe.
    if state.get("previous_actions") is None:
        state["previous_actions"] = []

    try:
        if action_name == "call_agent":
            agent_name = params["agent_name"]
            task = params["task"]
            if agent_name not in agents:
                raise ValueError(f"Unknown agent {agent_name!r}; available agents: {sorted(agents)}")
            agent_result = agents[agent_name].handle_task(task, state["memory"])
            success_message = (
                f"{agent_name} task completed and stored in the memory successfully"
                if agent_result
                else f"{agent_name} task failed"
            )
            state["previous_actions"].append(f"Called agent {agent_name} with task: {task}")
            return {**state, "agent_result": success_message}

        if action_name == "write_memory":
            partition = params["partition"]
            data = _coerce_memory_data(params["data"])
            updated_data = {**data, "last_updated": datetime.now().isoformat()}
            # Top-level partitions (user_profile, medical_history, …) are
            # merged so a partial update does not erase pre-existing keys.
            # Dotted paths (``flags_and_assessments.last_validation``) are
            # treated as a leaf assignment.
            if "." in partition:
                set_nested(state["memory"], partition, updated_data)
            else:
                update_memory_partition(state["memory"], partition, updated_data)
            state["previous_actions"].append(f"Wrote to memory partition: {partition}")
            return {**state, "agent_result": "Memory updated successfully"}

        if action_name == "compose_response":
            response_text = params.get("text") or params.get("response")
            if not response_text:
                raise ValueError("Missing 'text' or 'response' in params for compose_response")
            state["conversation_history"].append({"role": "assistant", "content": response_text})
            state["previous_actions"].append("Composed response to user")
            return {**state, "agent_result": response_text}

        if action_name == "ask_user":
            prompt_text = params["prompt"]
            state["conversation_history"].append({"role": "assistant", "content": prompt_text})
            state["previous_actions"].append(f"Asked user: {prompt_text}")
            return {**state, "agent_result": "User prompted for input"}

        state["previous_actions"].append(f"Executed {action_name} with params: {params}")
        return {**state, "agent_result": f"Unknown action: {action_name}"}

    except Exception as e:  # noqa: BLE001
        _logger.exception("Error executing %s", action_name)
        state["previous_actions"].append(f"Attempted {action_name} with params: {params}")
        result = {**state, "agent_result": f"Error executing {action_name}: {str(e)}"}
        if action_name in {"compose_response", "ask_user"}:
            # ``should_continue`` ends the run on these action names whatever
            # the outcome, so a failed terminal action would finish without
            # any message for the user. Clearing the decision routes back to
            # the Coach (still bounded by ``max_turns``) so it can retry.
            result["current_action"] = None
        return result


def setup_workflow(coach_agent, agents: Dict[str, Any], tools: Dict[str, Any], persistence_dir: str | None = None):
    """Build and compile the Coach ↔ execute_action graph.

    ``coach`` is the entry node and always flows into ``execute_action``.
    After each action, ``should_continue`` either routes back to ``coach`` or
    ends the run.

    Args:
        coach_agent: Object whose ``handle_task(state)`` produces the next decision.
        agents: Sub-agents by name, forwarded to ``execute_action_node``.
        tools: Tool registry, forwarded to ``execute_action_node``.
        persistence_dir: When truthy, checkpoints are written by a
            ``FileCheckpointSaver`` rooted at this path; otherwise an
            in-memory ``MemorySaver`` is used.

    Returns:
        The compiled graph, with the selected checkpointer attached.
    """
    def run_coach(state):
        return coach_node(state, coach_agent)

    def run_action(state):
        return execute_action_node(state, agents, tools)

    graph = StateGraph(NutritionState)
    graph.add_node("coach", run_coach)
    graph.add_node("execute_action", run_action)
    graph.set_entry_point("coach")
    graph.add_edge("coach", "execute_action")
    graph.add_conditional_edges(
        "execute_action",
        should_continue,
        {"execute_action": "coach", "end": END},
    )

    if not persistence_dir:
        checkpointer = MemorySaver()
        _logger.info("MAS workflow compiled with in-memory persistence.")
    else:
        checkpointer = FileCheckpointSaver(persistence_dir)
        _logger.info("MAS workflow compiled with file-based persistence at %s.", persistence_dir)

    return graph.compile(checkpointer=checkpointer)