Spaces:
Sleeping
Sleeping
| from dataclasses import asdict, dataclass | |
| from logging import getLogger | |
| from typing import TYPE_CHECKING, Any, Dict, List, TypedDict, Union | |
| from smolagents.models import ChatMessage, MessageRole | |
| from smolagents.monitoring import AgentLogger, LogLevel | |
| from smolagents.utils import AgentError, make_json_serializable | |
| if TYPE_CHECKING: | |
| from smolagents.models import ChatMessage | |
| from smolagents.monitoring import AgentLogger | |
# Module-level logger named after this module's import path.
logger = getLogger(name=__name__)
class Message(TypedDict):
    """Typed dictionary describing one chat message built by the memory steps."""

    # Who the message is attributed to (a `MessageRole` member).
    role: MessageRole
    # Either a plain string or a list of content-part dicts; the steps in this
    # module build parts shaped like {"type": "text", "text": ...} and
    # {"type": "image", "image": ...}.
    content: Union[str, list[dict]]
@dataclass
class ToolCall:
    """Record of a single tool invocation: which tool, with what arguments, under which call id.

    Declared as a dataclass so it can be built with ``ToolCall(name=..., arguments=..., id=...)``;
    without the decorator the annotated fields produce no ``__init__`` and construction fails.
    """

    name: str
    arguments: Any
    id: str

    def dict(self):
        """Return the call as a nested "function"-typed dictionary.

        Returns:
            dict: ``{"id", "type": "function", "function": {"name", "arguments"}}`` where the
            arguments have been passed through ``make_json_serializable``.
        """
        return {
            "id": self.id,
            "type": "function",
            "function": {
                "name": self.name,
                "arguments": make_json_serializable(self.arguments),
            },
        }
@dataclass
class MemoryStep:
    """Base class for the entries stored in an agent's memory.

    It is a dataclass because ``dict()`` relies on ``dataclasses.asdict``, which raises
    ``TypeError`` when handed an instance of a plain class.
    """

    def dict(self):
        """Return the step's dataclass fields as a (recursively converted) dictionary."""
        return asdict(self)

    def to_messages(self, **kwargs) -> List[Dict[str, Any]]:
        """Convert the step into chat messages; concrete step types must override this.

        Raises:
            NotImplementedError: Always, on the base class.
        """
        raise NotImplementedError
@dataclass
class ActionStep(MemoryStep):
    """Memory entry for one action step: the model call, the tool calls it issued and what came back.

    Every field defaults to ``None`` so a step can be created empty and filled in progressively.
    """

    model_input_messages: List[Message] | None = None
    tool_calls: List[ToolCall] | None = None
    start_time: float | None = None
    end_time: float | None = None
    step_number: int | None = None
    error: AgentError | None = None
    duration: float | None = None
    model_output_message: ChatMessage | None = None
    model_output: str | None = None
    observations: str | None = None
    observations_images: List[str] | None = None
    action_output: Any = None

    def dict(self):
        """Return the step as a dictionary.

        Overrides ``MemoryStep.dict`` so that ``tool_calls``, ``error`` and ``action_output`` are
        converted by hand instead of through ``dataclasses.asdict``. Note that the step number is
        exposed under the key ``"step"`` and that ``observations_images`` is not included.
        """
        return {
            "model_input_messages": self.model_input_messages,
            "tool_calls": [tc.dict() for tc in self.tool_calls] if self.tool_calls else [],
            "start_time": self.start_time,
            "end_time": self.end_time,
            "step": self.step_number,
            "error": self.error.dict() if self.error else None,
            "duration": self.duration,
            "model_output_message": self.model_output_message,
            "model_output": self.model_output,
            "observations": self.observations,
            "action_output": make_json_serializable(self.action_output),
        }

    def to_messages(self, summary_mode: bool = False, show_model_input_messages: bool = False) -> List[Message]:
        """Render the step as a list of chat messages.

        Args:
            summary_mode (bool): If True, the raw model output message is left out.
            show_model_input_messages (bool): If True, the stored model input messages are
                emitted first, wrapped in a single system-role message.

        Returns:
            List[Message]: In order, when the corresponding data is present: model input, model
            output, tool calls, observations, error and observed images.
        """
        messages = []
        # Observations and errors are tagged with the id of the first tool call when there is
        # one. The guard covers both `None` and an empty list so a step that recorded
        # observations without any tool call no longer fails on `self.tool_calls[0]`.
        call_id_prefix = f"Call id: {self.tool_calls[0].id}\n" if self.tool_calls else ""

        if self.model_input_messages is not None and show_model_input_messages:
            messages.append(Message(role=MessageRole.SYSTEM, content=self.model_input_messages))
        if self.model_output is not None and not summary_mode:
            messages.append(
                Message(role=MessageRole.ASSISTANT, content=[{"type": "text", "text": self.model_output.strip()}])
            )

        if self.tool_calls is not None:
            messages.append(
                Message(
                    role=MessageRole.ASSISTANT,
                    content=[
                        {
                            "type": "text",
                            "text": "Calling tools:\n" + str([tc.dict() for tc in self.tool_calls]),
                        }
                    ],
                )
            )

        if self.observations is not None:
            messages.append(
                Message(
                    role=MessageRole.TOOL_RESPONSE,
                    content=[
                        {
                            "type": "text",
                            "text": f"{call_id_prefix}Observation:\n{self.observations}",
                        }
                    ],
                )
            )

        if self.error is not None:
            error_message = (
                "Error:\n"
                + str(self.error)
                + "\nNow let's retry: take care not to repeat previous errors! If you have retried several times, try a completely different approach.\n"
            )
            messages.append(
                Message(
                    role=MessageRole.TOOL_RESPONSE,
                    content=[{"type": "text", "text": call_id_prefix + error_message}],
                )
            )

        if self.observations_images:
            messages.append(
                Message(
                    role=MessageRole.USER,
                    content=[{"type": "text", "text": "Here are the observed images:"}]
                    + [
                        {
                            "type": "image",
                            "image": image,
                        }
                        for image in self.observations_images
                    ],
                )
            )
        return messages
@dataclass
class PlanningStep(MemoryStep):
    """Memory entry for a planning step: the facts list and the plan produced by the model.

    All fields are required. The class is a dataclass so it gets a keyword constructor and so
    that the inherited ``dict()`` (based on ``dataclasses.asdict``) works on it.
    """

    model_input_messages: List[Message]
    model_output_message_facts: ChatMessage
    facts: str
    model_output_message_plan: ChatMessage
    plan: str

    def to_messages(self, summary_mode: bool = False, **kwargs) -> List[Message]:
        """Render the planning step as assistant messages.

        Args:
            summary_mode (bool): If True, only the facts message is returned. Defaults to False,
                matching the other step types' ``to_messages`` signatures.
            **kwargs: Accepted for signature compatibility with other steps; unused.

        Returns:
            List[Message]: The "[FACTS LIST]" message, followed by the "[PLAN]" message unless
            ``summary_mode`` is set.
        """
        messages = [
            Message(
                role=MessageRole.ASSISTANT, content=[{"type": "text", "text": f"[FACTS LIST]:\n{self.facts.strip()}"}]
            )
        ]
        if not summary_mode:  # This step is not shown to a model writing a plan to avoid influencing the new plan
            messages.append(
                Message(
                    role=MessageRole.ASSISTANT, content=[{"type": "text", "text": f"[PLAN]:\n{self.plan.strip()}"}]
                )
            )
        return messages
@dataclass
class TaskStep(MemoryStep):
    """Memory entry holding a task given to the agent, with optional attached images.

    Declared as a dataclass so it can be constructed as ``TaskStep(task=..., task_images=...)``.
    """

    task: str
    task_images: List[str] | None = None

    def to_messages(self, summary_mode: bool = False, **kwargs) -> List[Message]:
        """Render the task as a single user-role message.

        Args:
            summary_mode (bool): Accepted for signature compatibility; it does not change the output.
            **kwargs: Accepted for signature compatibility; unused.

        Returns:
            List[Message]: One message whose content is a "New task:" text part followed by one
            image part per entry in ``task_images``.
        """
        content = [{"type": "text", "text": f"New task:\n{self.task}"}]
        if self.task_images:
            content.extend({"type": "image", "image": image} for image in self.task_images)
        return [Message(role=MessageRole.USER, content=content)]
@dataclass
class SystemPromptStep(MemoryStep):
    """Memory entry wrapping the agent's system prompt.

    Declared as a dataclass so that ``SystemPromptStep(system_prompt=...)`` is a valid call.
    """

    system_prompt: str

    def to_messages(self, summary_mode: bool = False, **kwargs) -> List[Message]:
        """Render the system prompt as a single system-role message.

        Args:
            summary_mode (bool): If True, the prompt is omitted and an empty list is returned.
            **kwargs: Accepted for signature compatibility; unused.

        Returns:
            List[Message]: Either an empty list or one system message carrying the prompt text.
        """
        if summary_mode:
            return []
        return [Message(role=MessageRole.SYSTEM, content=[{"type": "text", "text": self.system_prompt}])]
class AgentMemory:
    """Container for an agent's system prompt and the ordered list of steps it has taken."""

    def __init__(self, system_prompt: str):
        """Create an empty memory.

        Args:
            system_prompt (str): Prompt text, stored wrapped in a ``SystemPromptStep``.
        """
        self.system_prompt = SystemPromptStep(system_prompt=system_prompt)
        self.steps: List[Union[TaskStep, ActionStep, PlanningStep]] = []

    def reset(self):
        """Forget all recorded steps; the system prompt is kept."""
        self.steps = []

    def get_succinct_steps(self) -> list[dict]:
        """Return every step as a dictionary with the ``"model_input_messages"`` entry removed."""
        return [
            {key: value for key, value in step.dict().items() if key != "model_input_messages"} for step in self.steps
        ]

    def get_full_steps(self) -> list[dict]:
        """Return every step as the dictionary produced by its ``dict()`` method."""
        return [step.dict() for step in self.steps]

    def replay(self, logger: AgentLogger, detailed: bool = False):
        """Prints a pretty replay of the agent's steps.

        Args:
            logger (AgentLogger): The logger to print replay logs to. It shadows the module-level
                ``logger`` inside this method.
            detailed (bool, optional): If True, also displays the system prompt and the memory at
                each step. Defaults to False.
                Careful: will increase log length exponentially. Use only for debugging.
        """
        logger.console.log("Replaying the agent's steps:")
        if detailed:
            # The system prompt lives in `self.system_prompt`, not in `self.steps`, so the
            # per-step check below never sees it unless a caller appended one by hand.
            logger.log_markdown(title="System prompt", content=self.system_prompt.system_prompt, level=LogLevel.ERROR)
        for step in self.steps:
            if isinstance(step, SystemPromptStep) and detailed:
                logger.log_markdown(title="System prompt", content=step.system_prompt, level=LogLevel.ERROR)
            elif isinstance(step, TaskStep):
                logger.log_task(step.task, "", level=LogLevel.ERROR)
            elif isinstance(step, ActionStep):
                logger.log_rule(f"Step {step.step_number}", level=LogLevel.ERROR)
                if detailed:
                    # Same explicit level as the planning branch, so both kinds of step are
                    # replayed under the same conditions.
                    logger.log_messages(step.model_input_messages, level=LogLevel.ERROR)
                # NOTE(review): `model_output` may be None on an ActionStep; confirm that
                # `log_markdown` accepts that.
                logger.log_markdown(title="Agent output:", content=step.model_output, level=LogLevel.ERROR)
            elif isinstance(step, PlanningStep):
                logger.log_rule("Planning step", level=LogLevel.ERROR)
                if detailed:
                    logger.log_messages(step.model_input_messages, level=LogLevel.ERROR)
                logger.log_markdown(title="Agent output:", content=step.facts + "\n" + step.plan, level=LogLevel.ERROR)
# Names exported by `from <this module> import *`.
__all__ = ["AgentMemory"]