File size: 4,376 Bytes
5374a2d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 |
from enum import Enum
from pydantic import Field
from typing import Union, Optional, List
from ..core.module import BaseModule
from ..core.message import Message, MessageType
from ..models.base_model import LLMOutputParser
class TrajectoryState(str, Enum):
    """
    Outcome recorded for a single step of an execution trajectory.

    Because members also subclass ``str``, each one compares equal to its own
    name, e.g. ``TrajectoryState.FAILED == "FAILED"``.
    """

    # The step succeeded.
    COMPLETED = "COMPLETED"
    # The step did not succeed.
    FAILED = "FAILED"
class TrajectoryStep(BaseModule):
    """
    One entry of an execution trajectory: a message together with its status
    and an optional error description.
    """

    # The recorded message. Annotated Optional so that the default (None) is
    # also accepted when passed explicitly and the type hint is truthful.
    message: Optional[Message] = None
    # Status of this step (required).
    status: TrajectoryState
    # Error description for this step, if any.
    error: Optional[str] = None
class Environment(BaseModule):
    """
    Responsible for storing and managing intermediate states of execution.

    Attributes:
        trajectory: One TrajectoryStep per message passed to `update`, in
            insertion order.
        task_execution_history: `wf_task` names of RESPONSE messages, with
            consecutive repeats of the same task collapsed into one entry.
        execution_data: Key/value data merged from message contents (dicts or
            LLMOutputParser structured data) and from context extraction.
    """
    trajectory: List[TrajectoryStep] = Field(default_factory=list)
    task_execution_history: List[str] = Field(default_factory=list)
    execution_data: dict = Field(default_factory=dict)

    def update(self, message: Message, state: Optional[TrajectoryState] = None, error: Optional[str] = None, **kwargs):
        """
        Record a message in the trajectory and refresh the derived execution state.

        Args:
            message (Message): The message to record.
            state (TrajectoryState, optional): Status of this step. Defaults to
                TrajectoryState.COMPLETED when not given.
            error (str, optional): Error description for this step, if any.
            **kwargs: Accepted for interface compatibility; currently unused.
        """
        state = state or TrajectoryState.COMPLETED
        step = TrajectoryStep(message=message, status=state, error=error)
        self.trajectory.append(step)
        self.update_task_execution_history(message=message)
        self.update_execution_data(message=message)

    def update_task_execution_history(self, message: Message):
        """
        Append the message's task to `task_execution_history` if it is a RESPONSE.

        Messages without a `wf_task`, or whose `msg_type` is not
        MessageType.RESPONSE, are ignored. A task that is already the most
        recent history entry is not appended again.
        """
        if message.wf_task is None or message.msg_type != MessageType.RESPONSE:
            return
        # If there are multiple actions (responses) for a task in a row, only record it once.
        if not self.task_execution_history or message.wf_task != self.task_execution_history[-1]:
            self.task_execution_history.append(message.wf_task)

    def update_execution_data(self, message: Message):
        """
        Merge data carried by the message content into `execution_data`.

        LLMOutputParser contents contribute `get_structured_data()`, and plain
        dict contents are merged directly. Any other content type is ignored.
        Existing keys are overwritten.
        """
        content = message.content
        if isinstance(content, LLMOutputParser):
            self.execution_data.update(content.get_structured_data())
        elif isinstance(content, dict):
            self.execution_data.update(content)

    def update_execution_data_from_context_extraction(self, extracted_data: dict):
        """
        Add extracted key/value pairs to `execution_data` without overwriting.

        Keys that already exist in `execution_data` keep their current value.
        """
        for key, value in extracted_data.items():
            self.execution_data.setdefault(key, value)

    def get_task_messages(self, tasks: Union[str, List[str]], n: Optional[int] = None, include_inputs: bool = False, **kwargs) -> List[Message]:
        """
        Retrieve messages related to the specified tasks, in trajectory order.

        Args:
            tasks (Union[str, List[str]]): A task name or a list of task names.
                A message is selected when its `wf_task` is one of them.
            n (int, optional): If given, return only the last `n` selected
                messages; `n == 0` returns an empty list.
            include_inputs (bool): Also select every message whose `msg_type`
                is MessageType.INPUT, regardless of its task. Each trajectory
                step contributes its message at most once.
            **kwargs: Accepted for interface compatibility; currently unused.

        Returns:
            List[Message]: A list of messages related to the tasks.
        """
        if isinstance(tasks, str):
            tasks = [tasks]
        # A single combined condition selects each step at most once. This avoids
        # an O(n^2) `not in` check, which also compared messages by value.
        message_list = [
            step.message
            for step in self.trajectory
            if (step.message.wf_task is not None and step.message.wf_task in tasks)
            or (include_inputs and step.message.msg_type == MessageType.INPUT)
        ]
        if n is None:
            return message_list
        # message_list[-0:] would be the whole list, so n == 0 is handled explicitly.
        return message_list[-n:] if n else []

    def get(self, n: Optional[int] = None) -> List[Message]:
        """
        Return the most recent `n` messages (all messages when `n` is None).

        `n == 0` returns an empty list.
        """
        assert n is None or n >= 0, "n must be None or a non-negative int"
        all_messages = [step.message for step in self.trajectory]
        if n is None:
            return all_messages
        # all_messages[-0:] would be the whole list, so n == 0 is handled explicitly.
        return all_messages[-n:] if n else []

    def get_last_executed_task(self) -> Optional[str]:
        """Return the most recently recorded task name, or None if none was recorded."""
        if self.task_execution_history:
            return self.task_execution_history[-1]
        return None

    def get_all_execution_data(self) -> dict:
        """Return the `execution_data` dict itself (not a copy); mutating it mutates this environment."""
        return self.execution_data

    def get_execution_data(self, params: Union[str, List[str]]) -> dict:
        """
        Return a dict holding the requested keys of `execution_data`.

        Args:
            params (Union[str, List[str]]): A key or a list of keys to look up.

        Raises:
            KeyError: If any requested key is missing from `execution_data`.
        """
        if isinstance(params, str):
            params = [params]
        data = {}
        for param in params:
            if param not in self.execution_data:
                raise KeyError(f"Couldn't find execution data with key '{param}'. Available execution data: {list(self.execution_data.keys())}")
            data[param] = self.execution_data[param]
        return data
|