File size: 20,982 Bytes
5374a2d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
import inspect
import asyncio
from copy import deepcopy
from pydantic import Field, create_model
from typing import Optional, List
from ..core.logging import logger
from ..core.module import BaseModule
from ..core.message import Message, MessageType
from ..core.module_utils import generate_id
from ..models.base_model import BaseLLM
from ..agents.agent import Agent
from ..agents.agent_manager import AgentManager, AgentState
from ..storages.base import StorageHandler
from .environment import Environment, TrajectoryState
from .workflow_manager import WorkFlowManager, NextAction
from .workflow_graph import WorkFlowNode, WorkFlowGraph
from .action_graph import ActionGraph
from ..hitl import HITLManager, HITLBaseAgent
from ..utils.utils import generate_dynamic_class_name
from ..actions import ActionInput, ActionOutput

class WorkFlow(BaseModule):
    """
    Executable workflow that drives a ``WorkFlowGraph`` to completion.

    The workflow repeatedly asks the ``WorkFlowManager`` to schedule the next
    subtask and the next action for it, executes that action either through an
    ``ActionGraph`` or through agents held by the ``AgentManager``, records every
    resulting ``Message`` in the ``Environment`` trajectory, and finally asks the
    manager to extract the overall output from the accumulated execution data.
    """

    graph: WorkFlowGraph
    llm: Optional[BaseLLM] = None
    agent_manager: AgentManager = Field(default=None, description="Responsible for managing agents")
    workflow_manager: WorkFlowManager = Field(default=None, description="Responsible for task and action scheduling for workflow execution")
    environment: Environment = Field(default_factory=Environment)
    storage_handler: StorageHandler = None
    workflow_id: str = Field(default_factory=generate_id)
    version: int = 0
    max_execution_steps: int = Field(default=5, description="The maximum number of steps to complete a subtask (node) in the workflow")
    hitl_manager: HITLManager = Field(default=None, description="Responsible for HITL work management")

    def init_module(self):
        """Build a default ``WorkFlowManager`` from ``llm`` when none was supplied.

        Raises:
            ValueError: If both ``workflow_manager`` and ``llm`` are ``None``.
        """
        if self.workflow_manager is None:
            if self.llm is None:
                raise ValueError("Must provide `llm` when `workflow_manager` is None")
            self.workflow_manager = WorkFlowManager(llm=self.llm)
        if self.agent_manager is None:
            logger.warning("agent_manager is NoneType when initializing a WorkFlow instance")

    def execute(self, inputs: Optional[dict] = None, **kwargs) -> str:
        """
        Synchronous wrapper for async_execute. Creates a new event loop and runs the async method.

        Args:
            inputs: Dictionary of inputs for workflow execution
            **kwargs (Any): Additional keyword arguments

        Returns:
            str: The output of the workflow execution
        """
        # A fresh event loop is created so this can be called from plain
        # synchronous code. It must NOT be called from inside a running loop;
        # use `async_execute` directly in that case.
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        try:
            return loop.run_until_complete(self.async_execute(inputs, **kwargs))
        finally:
            loop.close()

    async def async_execute(self, inputs: Optional[dict] = None, **kwargs) -> str:
        """
        Asynchronously execute the workflow.

        Args:
            inputs: Dictionary of inputs for workflow execution. A shallow copy
                is made internally, so the caller's dict is never mutated.
            **kwargs (Any): Additional keyword arguments

        Returns:
            str: The output of the workflow execution, or the literal string
            "Workflow Execution Failed" when a subtask raises an exception.
        """
        goal = self.graph.goal
        # Copy the caller's inputs before `_prepare_inputs` mutates them.
        # (Previously a mutable default `inputs: dict = {}` was updated in
        # place, leaking "goal" into the shared default across calls.)
        inputs = self._prepare_inputs(dict(inputs or {}))

        # prepare for hitl functionalities
        if self.hitl_manager is not None:
            self._prepare_hitl()

        # check the inputs and outputs of the task
        self._validate_workflow_structure(inputs=inputs, **kwargs)
        inp_message = Message(content=inputs, msg_type=MessageType.INPUT, wf_goal=goal)
        self.environment.update(message=inp_message, state=TrajectoryState.COMPLETED)

        failed = False
        error_message = None
        while not self.graph.is_complete and not failed:
            try:
                task: WorkFlowNode = await self.get_next_task()
                if task is None:
                    break
                logger.info(f"Executing subtask: {task.name}")
                await self.execute_task(task=task)
            except Exception as e:
                # Any failure inside a subtask aborts the whole workflow; the
                # error is recorded in the environment trajectory for later
                # inspection rather than re-raised.
                failed = True
                error_message = Message(
                    content=f"An Error occurs when executing the workflow: {e}",
                    msg_type=MessageType.ERROR, 
                    wf_goal=goal
                )
                self.environment.update(message=error_message, state=TrajectoryState.FAILED, error=str(e))
        
        if failed:
            logger.error(error_message.content)
            return "Workflow Execution Failed"
        
        logger.info("Extracting WorkFlow Output ...")
        output: str = await self.workflow_manager.extract_output(graph=self.graph, env=self.environment)
        return output
    
    def _prepare_inputs(self, inputs: dict) -> dict:
        """
        Prepare the inputs for the workflow execution. Mainly determine whether the goal should be added to the inputs.

        Args:
            inputs: The (already copied) execution inputs; mutated in place.

        Returns:
            dict: The same dict, with ``"goal"`` injected when an initial node
            requires it and the caller did not provide one.
        """
        initial_node_names = self.graph.find_initial_nodes()
        initial_node_required_inputs = set()
        for initial_node_name in initial_node_names:
            initial_node = self.graph.get_node(initial_node_name)
            if initial_node.inputs:
                initial_node_required_inputs.update([inp.name for inp in initial_node.inputs if inp.required])
        if "goal" in initial_node_required_inputs and "goal" not in inputs:
            inputs.update({"goal": self.graph.goal})
            
        return inputs 
    
    async def get_next_task(self) -> WorkFlowNode:
        """Ask the workflow manager to schedule the next subtask.

        Returns:
            WorkFlowNode: The next node to execute (may be ``None`` when the
            manager decides there is nothing left to schedule).
        """
        task_execution_history = " -> ".join(self.environment.task_execution_history)
        if not task_execution_history:
            task_execution_history = "None"
        logger.info(f"Task Execution Trajectory: {task_execution_history}. Scheduling next subtask ...")
        task: WorkFlowNode = await self.workflow_manager.schedule_next_task(graph=self.graph, env=self.environment)
        logger.info(f"The next subtask to be executed is: {task.name}")
        return task
        
    async def execute_task(self, task: WorkFlowNode):
        """
        Asynchronously execute a workflow task.
        
        Args:
            task: The workflow node to execute
        """
        last_executed_task = self.environment.get_last_executed_task()
        self.graph.step(source_node=last_executed_task, target_node=task)
        next_action: NextAction = await self.workflow_manager.schedule_next_action(
            goal=self.graph.goal,
            task=task, 
            agent_manager=self.agent_manager, 
            env=self.environment
        )
        # An action graph (if provided) takes precedence over agent execution.
        if next_action.action_graph is not None:
            await self._async_execute_task_by_action_graph(task=task, next_action=next_action)
        else:
            await self._async_execute_task_by_agents(task=task, next_action=next_action)
        self.graph.completed(node=task)

    async def _async_execute_task_by_action_graph(self, task: WorkFlowNode, next_action: NextAction):
        """
        Asynchronously execute a task using an action graph.
        
        Args:
            task: The workflow node to execute
            next_action: The next action to perform with its action graph
        """
        action_graph: ActionGraph = next_action.action_graph
        # Heuristic: subclasses that did not override `async_execute` keep the
        # base implementation raising NotImplementedError, so fall back to the
        # synchronous `execute` in that case.
        async_execute_source = inspect.getsource(action_graph.async_execute)
        if "NotImplementedError" in async_execute_source:
            execute_function = action_graph.execute
            async_execute = False
        else:
            execute_function = action_graph.async_execute
            async_execute = True
        # Bind call parameters by name from the environment's execution data;
        # fall back to the parameter default, then to None.
        execute_signature = inspect.signature(execute_function)
        execute_params = {}
        action_input_data = self.environment.get_all_execution_data() 
        for param_name, param_obj in execute_signature.parameters.items():
            if param_name in ["self", "args", "kwargs"]:
                continue
            if param_name in action_input_data:
                execute_params[param_name] = action_input_data[param_name]
            elif param_obj.default is not param_obj.empty:
                execute_params[param_name] = param_obj.default 
            else:
                execute_params[param_name] = None
        if async_execute:
            action_graph_output: dict = await action_graph.async_execute(**execute_params)
        else:
            action_graph_output: dict = action_graph.execute(**execute_params)

        message = Message(
            content=action_graph_output, action=action_graph.name, msg_type=MessageType.RESPONSE,
            wf_goal=self.graph.goal, wf_task=task.name, wf_task_desc=task.description
        )
        self.environment.update(message=message, state=TrajectoryState.COMPLETED)
    
    async def _async_execute_task_by_agents(self, task: WorkFlowNode, next_action: NextAction):
        """
        Asynchronously execute a task using agents.
        
        Args:
            task: The workflow node to execute
            next_action: The next action to perform using agents

        Raises:
            ValueError: When the task is still incomplete after
                ``max_execution_steps`` action executions.
            TimeoutError: When the scheduled agent does not become available
                within 300 seconds.
        """
        num_execution = 0 
        while next_action:
            if num_execution >= self.max_execution_steps:
                raise ValueError(
                    f"Maximum number of steps ({self.max_execution_steps}) reached when executing {task.name}. "
                    "Please check the workflow structure (e.g., inputs and outputs of the nodes and the agents) "
                    "or increase the `max_execution_steps` parameter."
                )
            agent: Agent = self.agent_manager.get_agent(agent_name=next_action.agent)
            if not self.agent_manager.wait_for_agent_available(agent_name=agent.name, timeout=300):
                raise TimeoutError(f"Timeout waiting for agent {agent.name} to become available")
            self.agent_manager.set_agent_state(agent_name=next_action.agent, new_state=AgentState.RUNNING)
            try:
                message = await self._async_execute_action(task=task, agent=agent, next_action=next_action)
                self.environment.update(message=message, state=TrajectoryState.COMPLETED)
            finally:
                # Always release the agent, even when execution raised.
                self.agent_manager.set_agent_state(agent_name=next_action.agent, new_state=AgentState.AVAILABLE)
            if self.is_task_completed(task=task):
                break
            next_action: NextAction = await self.workflow_manager.schedule_next_action(
                goal=self.graph.goal,
                task=task,
                agent_manager=self.agent_manager, 
                env=self.environment
            )
            num_execution += 1 

    async def _async_execute_action(self, task: WorkFlowNode, agent: Agent, next_action: NextAction) -> Message:
        """
        Asynchronously execute an action using an agent.

        Gathers the action's inputs from the environment's execution data. When
        some required inputs are missing, falls back to passing the predecessor
        tasks' messages so the agent can extract the inputs from context.

        Args:
            task: The workflow node being executed.
            agent: The agent that will run the action.
            next_action: Scheduling result naming the action to run.

        Returns:
            Message: The agent's response message.
        """
        action_name = next_action.action
        all_execution_data = self.environment.get_all_execution_data()

        # `hitl_manager` is a declared field defaulting to None, so a plain
        # attribute read is sufficient (no hasattr needed).
        hitl_manager = self.hitl_manager

        action_inputs_format = agent.get_action(action_name).inputs_format
        action_input_data = {} 
        if action_inputs_format:
            for input_name in action_inputs_format.get_attrs():
                if input_name in all_execution_data:
                    action_input_data[input_name] = all_execution_data[input_name]
            action_required_input_names = action_inputs_format.get_required_input_names()
            if not all(inp in action_input_data for inp in action_required_input_names):
                # could not find all the required inputs in the execution data
                predecessors = self.graph.get_node_predecessors(node=task)
                predecessors_messages = self.environment.get_task_messages(
                    tasks=predecessors + [task.name], include_inputs=True
                )
                predecessors_messages = [
                    message for message in predecessors_messages 
                    if message.msg_type in [MessageType.INPUT, MessageType.RESPONSE]
                ]
                message, extracted_data = await agent.async_execute(
                    action_name=action_name, 
                    msgs=predecessors_messages,
                    return_msg_type=MessageType.RESPONSE,
                    return_action_input_data=True,
                    wf_goal=self.graph.goal,
                    wf_task=task.name,
                    wf_task_desc=task.description,
                    hitl_manager=hitl_manager
                )
                # Persist whatever inputs the agent extracted from context so
                # later actions can reuse them.
                self.environment.update_execution_data_from_context_extraction(extracted_data)
                return message
        
        message = await agent.async_execute(
            action_name=action_name,
            action_input_data=action_input_data,
            return_msg_type=MessageType.RESPONSE,
            wf_goal=self.graph.goal,
            wf_task=task.name,
            wf_task_desc=task.description,
            hitl_manager=hitl_manager
        )
        return message
    
    def is_task_completed(self, task: WorkFlowNode) -> bool:
        """Return True when every declared output of ``task`` is present in the environment's execution data."""
        task_outputs = [output.name for output in task.outputs]
        current_execution_data = self.environment.get_all_execution_data()
        return all(output in current_execution_data for output in task_outputs)
    
    def _validate_workflow_structure(self, inputs: dict, **kwargs):
        """Validate that every node's required inputs are satisfiable and no forbidden agents are used.

        Each node's required inputs must be covered by the workflow-level
        ``inputs`` or by the outputs of its predecessor nodes.

        Raises:
            ValueError: When a node is missing required inputs, or when an
                agent is marked ``forbidden_in_workflow``.
        """
        # check the inputs and outputs of the nodes 
        input_names = set(inputs.keys())
        for node in self.graph.nodes:
            node_input_names = deepcopy(input_names)
            is_initial_node = True
            for name in self.graph.get_node_predecessors(node):
                is_initial_node = False 
                predecessor = self.graph.get_node(name)
                node_input_names.update(predecessor.get_output_names())
            node_required_input_names = set(node.get_input_names(required=True))
            if not all(input_name in node_input_names for input_name in node_required_input_names):
                missing_required_inputs = node_required_input_names - node_input_names 
                if is_initial_node:
                    raise ValueError(
                        f"The initial node '{node.name}' is missing required inputs: {list(missing_required_inputs)}. "
                        "You should provide these inputs by specifying the `inputs={'input_name': 'input_value'}` parameter in the `execute` method, "
                        "or return the valid inputs in the `collate_func` when using `Evaluator`."
                    )
                else:
                    raise ValueError(
                        f"The node '{node.name}' is missing required inputs: {list(missing_required_inputs)}. "
                        f"You may need to check the `inputs` and `outputs` of the nodes to ensure that all the required inputs of node '{node.name}' are provided "
                        f"by either its predecessors or the `inputs` parameter in the `execute` method."
                    )
        
        for node in self.graph.nodes:
            for agent in node.agents:
                if hasattr(agent, "forbidden_in_workflow") and (agent.forbidden_in_workflow):
                    raise ValueError(f"The Agent of class {agent.__class__} is forbidden to be used in the workflow.")

    def _prepare_single_hitl_agent(self, agent: Agent, node: WorkFlowNode):
        """
        add complementary information and settings which need dynamically setting up to a single hitl agent
        For example, the `inputs_format` attribute, this needs a dynamical setting up.
        Up to Now, we only consider a HITL agent must be the only agent in its WorkFlowNode instance, this condition may be changed in the future

        Args:
            agent (Agent): a single HITL Agent instance 
            node (WorkFlowNode): a single WorkFlowNode instance which contains exactly the agent of previous param.

        Raises:
            ValueError: When no HITL action is found on the agent, when the node
                is a terminal node, or when the HITL manager has no
                ``hitl_input_output_mapping``.
        """
        predecessors: List[str] = self.graph.get_node_predecessors(node)
        # Find the first action that still needs formats and exposes an
        # `interaction_type` (the marker of a HITL action).
        hitl_action = None
        for action in agent.actions:
            if (action.inputs_format) and (action.outputs_format):
                continue
            elif hasattr(action, "interaction_type"):
                hitl_action = action
                break
        if not hitl_action:
            raise ValueError(f"Can not find a HITL action in agent {agent}")

        hitl_inputs_data_fields = {}

        # set up inputs_format and outputs_format: the inputs mirror the
        # predecessors' outputs, the outputs mirror the successors' inputs.
        for predecessor in predecessors:
            predecessor_node = self.graph.get_node(predecessor)
            for param in predecessor_node.outputs:
                if param.required:
                    hitl_inputs_data_fields[param.name] = (str, Field(description=param.description))
                else:
                    hitl_inputs_data_fields[param.name] = (Optional[str], Field(description=param.description))
        inputs_format = create_model(
            agent._get_unique_class_name(
                generate_dynamic_class_name(hitl_action.class_name+" action_input")
            ),
            **(hitl_inputs_data_fields or {}),
            __base__= ActionInput
        )
        
        successors: List[str] = self.graph.get_node_children(node)
        hitl_outputs_data_fields = {}
        if successors == []:
            # hitl node as the ending node, not allowed for now
            raise ValueError("WorkFlowNode with a HITL Agent can not be set as the ending node.")
        for successor in successors:
            successor_node = self.graph.get_node(successor)
            for param in successor_node.inputs:
                if param.required:
                    hitl_outputs_data_fields[param.name] = (str, Field(description=param.description))
                else:
                    hitl_outputs_data_fields[param.name] = (Optional[str], Field(description=param.description))
        outputs_format = create_model(
            agent._get_unique_class_name(
                generate_dynamic_class_name(hitl_action.class_name+" action_output")
            ),
            **(hitl_outputs_data_fields or {}),
            __base__=ActionOutput
        )
        hitl_action.inputs_format = inputs_format
        hitl_action.outputs_format = outputs_format

        ## check hitl data field mapping
        if self.hitl_manager.hitl_input_output_mapping is None:
            raise ValueError("hitl_input_output_mapping attribute missing in HITLManager instance.")
        return

    def _prepare_hitl(self):
        """
        Prepare hitl settings before executing the WorkFlow

        Collects all HITL agents (resolving dict/str specs through the agent
        manager), checks that each HITL agent is the only agent of its node,
        and dynamically builds the input/output formats of each HITL action.

        Raises:
            ValueError: When a node mixes a HITL agent with other agents.
        """
        if self.hitl_manager is None:
            return
        hitl_agents: List[Agent]  = []
        node_with_hitl_agents = []
        for node in self.graph.nodes:
            agents = node.agents
            found_hitl_agent = False
            for agent in agents:
                # transfer to Agent instance
                if isinstance(agent, dict):
                    agent = self.agent_manager.get_agent(self.agent_manager.get_agent_name(agent))
                elif isinstance(agent, str):
                    agent = self.agent_manager.get_agent(agent)
                elif isinstance(agent, Agent):
                    pass
                # judgement
                if isinstance(agent, HITLBaseAgent):
                    found_hitl_agent = True
                    if agent not in hitl_agents:
                        hitl_agents.append(agent)
            if found_hitl_agent:
                node_with_hitl_agents.append(node)

        # Up to Now, we only consider a HITL agent must be the only agent in its WorkFlowNode instance, this condition may be changed in the future
        if len(hitl_agents) != len(node_with_hitl_agents):
            raise ValueError("Incorrect WorkFlowNode definition: A HITL Agent must be the only agent in its WorkFlowNode instance")

        # add complementary information and settings which need dynamically setting up to hitl agents
        for agent, node in zip(hitl_agents, node_with_hitl_agents):
            self._prepare_single_hitl_agent(agent, node)

        return