iLOVE2D's picture
Upload 2846 files
5374a2d verified
from enum import Enum
from pydantic import Field
from typing import Union, Optional, List
from ..core.module import BaseModule
from ..core.message import Message, MessageType
from ..models.base_model import LLMOutputParser
class TrajectoryState(str, Enum):
    """
    Status attached to a single step of an execution trajectory.

    The enum mixes in ``str``, so every member compares equal to its plain
    string value (for example ``TrajectoryState.FAILED == "FAILED"``).
    """

    # Status used by Environment.update when the caller passes no state.
    COMPLETED = "COMPLETED"
    # Marks a step as failed.
    FAILED = "FAILED"
class TrajectoryStep(BaseModule):
    """
    One entry of an execution trajectory: a message plus the status (and
    optional error text) it was recorded with.
    """

    # The recorded message. Annotated Optional because the default is None;
    # a bare ``Message`` annotation contradicts that default and would reject
    # an explicit ``message=None`` under model validation.
    message: Optional[Message] = None
    # Status of this step.
    status: TrajectoryState
    # Optional error description attached to the step.
    error: Optional[str] = None
class Environment(BaseModule):
    """
    Responsible for storing and managing intermediate states of execution.

    Attributes:
        trajectory: Every message recorded through ``update``, in arrival order,
            wrapped in a ``TrajectoryStep`` together with its status and error.
        task_execution_history: Workflow task names in the order their RESPONSE
            messages were recorded; consecutive responses from the same task are
            collapsed into a single entry.
        execution_data: Key/value data merged from message contents (structured
            ``LLMOutputParser`` output or plain dicts) and from context extraction.
    """

    trajectory: List[TrajectoryStep] = Field(default_factory=list)
    task_execution_history: List[str] = Field(default_factory=list)
    execution_data: dict = Field(default_factory=dict)

    def update(
        self,
        message: Message,
        state: Optional[TrajectoryState] = None,
        error: Optional[str] = None,
        **kwargs,
    ):
        """
        Record a message in the trajectory and refresh the derived state.

        Args:
            message (Message): The message to record.
            state (TrajectoryState, optional): Status of this step. Defaults to
                ``TrajectoryState.COMPLETED`` when not given.
            error (str, optional): Error description attached to the step.
            **kwargs: Accepted for interface compatibility; currently ignored.
        """
        state = state or TrajectoryState.COMPLETED
        step = TrajectoryStep(message=message, status=state, error=error)
        self.trajectory.append(step)
        self.update_task_execution_history(message=message)
        self.update_execution_data(message=message)

    def update_task_execution_history(self, message: Message):
        """
        Append the message's workflow task to the history if it is a RESPONSE.

        Messages without a ``wf_task`` or of any other type are ignored.
        """
        if message.wf_task is None or message.msg_type != MessageType.RESPONSE:
            return
        # A task may emit several responses (one per action); record it only once
        # as long as it is still the most recently executed task.
        if not self.task_execution_history or message.wf_task != self.task_execution_history[-1]:
            self.task_execution_history.append(message.wf_task)

    def update_execution_data(self, message: Message):
        """
        Merge data carried by the message content into ``execution_data``.

        ``LLMOutputParser`` content contributes ``get_structured_data()``; dict
        content is merged directly. Existing keys are overwritten. Other content
        types are ignored.
        """
        content = message.content
        if isinstance(content, LLMOutputParser):
            self.execution_data.update(content.get_structured_data())
        elif isinstance(content, dict):
            self.execution_data.update(content)

    def update_execution_data_from_context_extraction(self, extracted_data: dict):
        """
        Add extracted key/value pairs to ``execution_data`` without overwriting.

        Keys already present in ``execution_data`` keep their current value.
        """
        for key, value in extracted_data.items():
            self.execution_data.setdefault(key, value)

    def get_task_messages(
        self,
        tasks: Union[str, List[str]],
        n: Optional[int] = None,
        include_inputs: bool = False,
        **kwargs,
    ) -> List[Message]:
        """
        Retrieve messages related to the specified tasks.

        Args:
            tasks (str | List[str]): A task name or a list of task names.
            n (int, optional): Keep only the last ``n`` matching messages.
                ``None`` keeps all of them; ``0`` returns an empty list.
            include_inputs (bool): Also include every INPUT message, whatever
                task it belongs to.
            **kwargs: Accepted for interface compatibility; currently ignored.

        Returns:
            List[Message]: Matching messages in trajectory order.

        Raises:
            ValueError: If ``n`` is negative.
        """
        if n is not None and n < 0:
            raise ValueError("n must be None or a non-negative int")
        if isinstance(tasks, str):
            tasks = [tasks]
        message_list = []
        for step in self.trajectory:
            message = step.message
            is_task_message = message.wf_task is not None and message.wf_task in tasks
            is_input_message = include_inputs and message.msg_type == MessageType.INPUT
            # A single combined check appends each step at most once, without an
            # O(n) equality scan that would also drop distinct-but-equal messages.
            if is_task_message or is_input_message:
                message_list.append(message)
        if n is None:
            return message_list
        # ``lst[-0:]`` is the whole list, so n == 0 must be special-cased.
        return message_list[-n:] if n > 0 else []

    def get(self, n: Optional[int] = None) -> List[Message]:
        """
        Return the most recent ``n`` messages of the trajectory.

        Args:
            n (int, optional): Number of trailing messages to return. ``None``
                returns every message; ``0`` returns an empty list.

        Returns:
            List[Message]: The selected messages, oldest first.

        Raises:
            ValueError: If ``n`` is negative.
        """
        # Explicit check instead of ``assert`` so it survives ``python -O``.
        if n is not None and n < 0:
            raise ValueError("n must be None or a non-negative int")
        all_messages = [step.message for step in self.trajectory]
        if n is None:
            return all_messages
        # ``lst[-0:]`` is the whole list, so n == 0 must be special-cased.
        return all_messages[-n:] if n > 0 else []

    def get_last_executed_task(self) -> Optional[str]:
        """Return the most recently executed task name, or None if there is none."""
        return self.task_execution_history[-1] if self.task_execution_history else None

    def get_all_execution_data(self) -> dict:
        """Return the execution data dict itself (not a copy); mutations are shared."""
        return self.execution_data

    def get_execution_data(self, params: Union[str, List[str]]) -> dict:
        """
        Return a dict holding the requested keys of ``execution_data``.

        Args:
            params (str | List[str]): A key or a list of keys to fetch.

        Returns:
            dict: Mapping from each requested key to its stored value.

        Raises:
            KeyError: If any requested key is missing; the message lists the
                available keys.
        """
        if isinstance(params, str):
            params = [params]
        data = {}
        for param in params:
            if param not in self.execution_data:
                raise KeyError(f"Couldn't find execution data with key '{param}'. Available execution data: {list(self.execution_data.keys())}")
            data[param] = self.execution_data[param]
        return data