File size: 4,376 Bytes
5374a2d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
from enum import Enum
from pydantic import Field
from typing import Union, Optional, List
from ..core.module import BaseModule
from ..core.message import Message, MessageType
from ..models.base_model import LLMOutputParser


class TrajectoryState(str, Enum):
    """Possible statuses of a single trajectory step.

    Because the enum mixes in ``str``, each member compares equal to its
    plain-string value (e.g. ``TrajectoryState.FAILED == "FAILED"``).
    """

    COMPLETED = "COMPLETED"
    FAILED = "FAILED"


class TrajectoryStep(BaseModule):
    """
    A single recorded step of an execution trajectory.

    Attributes:
        message: The message recorded for this step (may be None).
        status: The status of the step (see ``TrajectoryState``).
        error: Optional error description attached to the step.
    """

    # Optional[...] so the annotation agrees with the None default; assuming
    # BaseModule is a pydantic model (Field is imported from pydantic in this
    # file), a bare ``Message`` annotation would reject an explicit message=None.
    message: Optional[Message] = None
    status: TrajectoryState
    error: Optional[str] = None


class Environment(BaseModule):
    """
    Responsible for storing and managing intermediate states of execution.

    Attributes:
        trajectory: Every step recorded via ``update``, in insertion order.
        task_execution_history: Workflow task names taken from RESPONSE
            messages; consecutive repeats of the same task are recorded once.
        execution_data: Key/value data merged from message contents and from
            context extraction.
    """

    trajectory: List[TrajectoryStep] = Field(default_factory=list)
    task_execution_history: List[str] = Field(default_factory=list)
    execution_data: dict = Field(default_factory=dict)

    def update(self, message: Message, state: Optional[TrajectoryState] = None, error: Optional[str] = None, **kwargs):
        """
        Record a message as a new trajectory step and refresh derived state.

        Besides appending a step to ``trajectory``, this updates
        ``task_execution_history`` and ``execution_data`` from the message.

        Args:
            message (Message): The message to record.
            state (TrajectoryState, optional): Status of the step; defaults to
                ``TrajectoryState.COMPLETED`` when not provided.
            error (str, optional): Error description stored on the step.
            **kwargs: Accepted for call-site compatibility; not used here.
        """
        state = state or TrajectoryState.COMPLETED
        step = TrajectoryStep(message=message, status=state, error=error)
        self.trajectory.append(step)
        self.update_task_execution_history(message=message)
        self.update_execution_data(message=message)

    def update_task_execution_history(self, message: Message):
        """
        Append the message's workflow task to the history for RESPONSE messages.

        Messages without a ``wf_task`` and non-RESPONSE messages are ignored.
        """
        if message.wf_task is None or message.msg_type != MessageType.RESPONSE:
            return
        # If there are multiple actions (responses) for a task, only record it
        # once per consecutive run rather than once per message.
        if not self.task_execution_history or self.task_execution_history[-1] != message.wf_task:
            self.task_execution_history.append(message.wf_task)

    def update_execution_data(self, message: Message):
        """
        Merge structured data carried by ``message.content`` into ``execution_data``.

        ``LLMOutputParser`` contents contribute ``get_structured_data()``; plain
        ``dict`` contents are merged directly. For duplicate keys the newly
        merged value overwrites the existing one. Other content types are ignored.
        """
        content = message.content
        if isinstance(content, LLMOutputParser):
            self.execution_data.update(content.get_structured_data())
        if isinstance(content, dict):
            self.execution_data.update(content)

    def update_execution_data_from_context_extraction(self, extracted_data: dict):
        """
        Add extracted key/value pairs to ``execution_data`` without overwriting.

        Keys that already exist in ``execution_data`` keep their current value.
        """
        for key, value in extracted_data.items():
            self.execution_data.setdefault(key, value)

    def get_task_messages(self, tasks: Union[str, List[str]], n: Optional[int] = None, include_inputs: bool = False, **kwargs) -> List[Message]:
        """
        Retrieve messages related to the specified tasks, in trajectory order.

        Args:
            tasks (Union[str, List[str]]): A task name, or list of task names,
                matched against ``message.wf_task``.
            n (int, optional): If given, keep only the last ``n`` matching
                messages; ``0`` yields an empty list.
            include_inputs (bool): Also include INPUT messages, regardless of task.
            **kwargs: Accepted for call-site compatibility; not used here.

        Returns:
            List[Message]: A list of messages related to the tasks.
        """
        if isinstance(tasks, str):
            tasks = [tasks]
        message_list = []
        for step in self.trajectory:
            message = step.message
            if message.wf_task is not None and message.wf_task in tasks:
                message_list.append(message)
            elif include_inputs and message.msg_type == MessageType.INPUT:
                # elif: an INPUT message that also matches a task is added only once.
                message_list.append(message)
        if n is None:
            return message_list
        # message_list[-0:] is message_list[0:] (the whole list), so 0 is special-cased.
        return message_list[-n:] if n != 0 else []

    def get(self, n: Optional[int] = None) -> List[Message]:
        """
        Return the most recent ``n`` messages from the trajectory.

        Args:
            n (int, optional): Number of messages to return. ``None`` returns
                every message; ``0`` returns an empty list.

        Returns:
            List[Message]: The selected messages in trajectory order.
        """
        assert n is None or n >= 0, "n must be None or a positive int"
        all_messages = [step.message for step in self.trajectory]
        if n is None:
            return all_messages
        # all_messages[-0:] is all_messages[0:] (the whole list), so 0 is special-cased.
        return all_messages[-n:] if n > 0 else []

    def get_last_executed_task(self) -> Optional[str]:
        """Return the most recently recorded task name, or None if none was recorded."""
        return self.task_execution_history[-1] if self.task_execution_history else None

    def get_all_execution_data(self) -> dict:
        """
        Return the execution-data dict itself (not a copy).

        Mutating the returned dict mutates this environment's ``execution_data``.
        """
        return self.execution_data

    def get_execution_data(self, params: Union[str, List[str]]) -> dict:
        """
        Return a dict containing only the requested execution-data keys.

        Args:
            params (Union[str, List[str]]): A key or list of keys to look up.

        Returns:
            dict: Mapping from each requested key to its stored value.

        Raises:
            KeyError: If a requested key is missing; the message lists the
                available keys.
        """
        if isinstance(params, str):
            params = [params]
        data = {}
        for param in params:
            if param not in self.execution_data:
                raise KeyError(f"Couldn't find execution data with key '{param}'. Available execution data: {list(self.execution_data.keys())}")
            data[param] = self.execution_data[param]
        return data