|
|
from enum import Enum |
|
|
from pydantic import Field |
|
|
from typing import Union, Optional, List |
|
|
from ..core.module import BaseModule |
|
|
from ..core.message import Message, MessageType |
|
|
from ..models.base_model import LLMOutputParser |
|
|
|
|
|
|
|
|
class TrajectoryState(str, Enum):
    """
    Outcome recorded for a single step of a trajectory.

    The enum mixes in ``str``, so each member compares equal to its
    plain string value (e.g. ``TrajectoryState.FAILED == "FAILED"``).
    """

    # The step finished without a reported problem.
    COMPLETED = "COMPLETED"

    # The step did not finish successfully.
    FAILED = "FAILED"
|
|
|
|
|
|
|
|
class TrajectoryStep(BaseModule):
    """
    One entry of an execution trajectory: a message together with the
    status recorded for it and an optional error description.
    """

    # The message recorded for this step. Declared Optional because the
    # default is None; a bare ``Message`` annotation contradicts that default.
    message: Optional[Message] = None

    # Outcome of the step; required, there is no default.
    status: TrajectoryState

    # Free-form error text, or None when no error was recorded.
    error: Optional[str] = None
|
|
|
|
|
|
|
|
class Environment(BaseModule):
    """
    Responsible for storing and managing intermediate states of execution.

    Three pieces of state are kept:

    * ``trajectory`` - every message passed to ``update``, wrapped in a
      ``TrajectoryStep``, in insertion order.
    * ``task_execution_history`` - task names taken from RESPONSE messages,
      with consecutive repeats collapsed.
    * ``execution_data`` - a flat key/value store merged from message
      contents.
    """

    trajectory: List[TrajectoryStep] = Field(default_factory=list)
    task_execution_history: List[str] = Field(default_factory=list)
    execution_data: dict = Field(default_factory=dict)

    def update(
        self,
        message: Message,
        state: Optional[TrajectoryState] = None,
        error: Optional[str] = None,
        **kwargs,
    ) -> None:
        """
        Record a message as a new trajectory step and refresh derived state.

        The message is appended to ``trajectory`` and then used to update
        ``task_execution_history`` and ``execution_data``.

        Args:
            message (Message): The message to be recorded.
            state (TrajectoryState, optional): Status stored on the step.
                Defaults to ``TrajectoryState.COMPLETED`` when not given.
            error (str, optional): Error text stored on the step.
            **kwargs: Accepted for call-site compatibility; not used here.
        """
        state = state or TrajectoryState.COMPLETED
        step = TrajectoryStep(message=message, status=state, error=error)
        self.trajectory.append(step)
        self.update_task_execution_history(message=message)
        self.update_execution_data(message=message)

    def update_task_execution_history(self, message: Message) -> None:
        """
        Append the message's task to the execution history when appropriate.

        Only RESPONSE messages that carry a ``wf_task`` are considered, and a
        task is not appended if it equals the most recent history entry.

        Args:
            message (Message): The message to inspect.
        """
        if message.wf_task is None or message.msg_type not in [MessageType.RESPONSE]:
            return
        history = self.task_execution_history
        # Collapse consecutive repeats of the same task into a single entry.
        if not history or message.wf_task != history[-1]:
            history.append(message.wf_task)

    def update_execution_data(self, message: Message) -> None:
        """
        Merge the content of a message into ``execution_data``.

        ``LLMOutputParser`` content contributes its structured data; ``dict``
        content is merged as-is. Existing keys are overwritten. Content of any
        other type is ignored.

        Args:
            message (Message): The message whose content is merged.
        """
        content = message.content
        if isinstance(content, LLMOutputParser):
            self.execution_data.update(content.get_structured_data())
        if isinstance(content, dict):
            self.execution_data.update(content)

    def update_execution_data_from_context_extraction(self, extracted_data: dict) -> None:
        """
        Add extracted values to ``execution_data`` without overwriting.

        Unlike ``update_execution_data``, keys that already exist keep their
        current value.

        Args:
            extracted_data (dict): Key/value pairs to add when missing.
        """
        for key, value in extracted_data.items():
            self.execution_data.setdefault(key, value)

    def get_task_messages(
        self,
        tasks: Union[str, List[str]],
        n: Optional[int] = None,
        include_inputs: bool = False,
        **kwargs,
    ) -> List[Message]:
        """
        Retrieve all messages related to specified tasks.

        Args:
            tasks (Union[str, List[str]]): A task name or a list of task names.
            n (int, optional): If given, only the last ``n`` matching messages
                are returned; ``n=0`` yields an empty list. ``None`` returns
                every matching message.
            include_inputs (bool): When True, INPUT messages are included as
                well, unless an equal message is already in the result.
            **kwargs: Accepted for call-site compatibility; not used here.

        Returns:
            List[Message]: A list of messages related to the task, in
            trajectory order.
        """
        if isinstance(tasks, str):
            tasks = [tasks]

        message_list = []
        for step in self.trajectory:
            message = step.message
            # A step may carry no message (the field defaults to None).
            if message is None:
                continue
            if message.wf_task is not None and message.wf_task in tasks:
                message_list.append(message)
            if include_inputs and message.msg_type == MessageType.INPUT and message not in message_list:
                message_list.append(message)

        if n is None:
            return message_list
        # lst[-0:] is the whole list, so zero must be handled explicitly.
        if n == 0:
            return []
        return message_list[-n:]

    def get(self, n: Optional[int] = None) -> List[Message]:
        """
        Return the most recent ``n`` messages of the trajectory.

        Args:
            n (int, optional): Number of trailing messages to return.
                ``None`` returns all messages; ``0`` returns an empty list.

        Returns:
            List[Message]: The selected messages in trajectory order.

        Raises:
            AssertionError: If ``n`` is negative.
        """
        assert n is None or n >= 0, "n must be None or a non-negative int"
        all_messages = [step.message for step in self.trajectory]
        if n is None:
            return all_messages
        # lst[-0:] is the whole list, so zero must be handled explicitly.
        if n == 0:
            return []
        return all_messages[-n:]

    def get_last_executed_task(self) -> Optional[str]:
        """
        Return the most recent entry of the task execution history.

        Returns:
            Optional[str]: The last recorded task, or None when the history
            is empty.
        """
        if self.task_execution_history:
            return self.task_execution_history[-1]
        return None

    def get_all_execution_data(self) -> dict:
        """
        Return the execution data store.

        Returns:
            dict: The ``execution_data`` dict itself (not a copy), so changes
            made by the caller are visible to this environment.
        """
        return self.execution_data

    def get_execution_data(self, params: Union[str, List[str]]) -> dict:
        """
        Look up one or more keys in the execution data.

        Args:
            params (Union[str, List[str]]): A key or a list of keys.

        Returns:
            dict: A new dict mapping each requested key to its stored value.

        Raises:
            KeyError: If a requested key is missing; raised for the first
                missing key encountered.
        """
        if isinstance(params, str):
            params = [params]

        data = {}
        for param in params:
            if param not in self.execution_data:
                raise KeyError(f"Couldn't find execution data with key '{param}'. Available execution data: {list(self.execution_data.keys())}")
            data[param] = self.execution_data[param]
        return data
|
|
|
|
|
|