|
|
from __future__ import annotations |
|
|
|
|
|
import logging |
|
|
from typing import Dict, List, Optional |
|
|
|
|
|
from langchain_core.messages import ( |
|
|
AIMessage, |
|
|
BaseMessage, |
|
|
HumanMessage, |
|
|
SystemMessage, |
|
|
ToolMessage, |
|
|
) |
|
|
from pydantic import BaseModel |
|
|
|
|
|
from browser_use.agent.message_manager.views import MessageMetadata |
|
|
from browser_use.agent.prompts import AgentMessagePrompt |
|
|
from browser_use.agent.views import ActionResult, AgentOutput, AgentStepInfo, MessageManagerState |
|
|
from browser_use.browser.views import BrowserState |
|
|
from browser_use.utils import time_execution_sync |
|
|
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
|
|
|
class MessageManagerSettings(BaseModel):
	"""Configuration for MessageManager: token budget and prompt options."""

	# Hard ceiling for the history's estimated token count; cut_messages trims
	# the last state message once this is exceeded.
	max_input_tokens: int = 128000
	# Heuristic divisor for estimating tokens from character counts
	# (tokens ~= len(text) // 3) — avoids a tokenizer dependency.
	estimated_characters_per_token: int = 3
	# Flat token cost charged per image attached to a message.
	image_tokens: int = 800
	# DOM element attributes to include when rendering browser state.
	# (Mutable default is safe here: pydantic deep-copies field defaults
	# per instance, unlike plain function defaults.)
	include_attributes: list[str] = []
	# Optional extra context appended to the initial task prompt.
	message_context: Optional[str] = None
	# Placeholder name -> secret value; values are redacted from stored
	# messages and replaced with <secret>name</secret> markers.
	sensitive_data: Optional[Dict[str, str]] = None
	# File paths the agent is told it may use, if any.
	available_file_paths: Optional[List[str]] = None
|
|
|
|
|
|
|
|
class MessageManager:
	"""Builds and maintains the LLM conversation history for an agent run.

	Every mutation goes through ``_add_message_with_tokens`` so each stored
	message carries a token-count estimate; ``cut_messages`` uses those
	estimates to keep the history under ``settings.max_input_tokens``.
	"""

	def __init__(
		self,
		task: str,
		system_message: SystemMessage,
		settings: Optional[MessageManagerSettings] = None,
		state: Optional[MessageManagerState] = None,
	):
		"""Create a manager for ``task``.

		Args:
			task: The user's task description, embedded into the initial prompt.
			system_message: System prompt placed at the start of the history.
			settings: Token limits and prompt options; a fresh default is
				created when omitted.
			state: Persisted history to resume from; a fresh default is
				created when omitted.
		"""
		self.task = task
		# Fix: the previous signature used instances created at def time as
		# default argument values, so all managers constructed without explicit
		# settings/state silently shared ONE settings object and ONE history.
		self.settings = settings if settings is not None else MessageManagerSettings()
		self.state = state if state is not None else MessageManagerState()
		self.system_prompt = system_message

		# Only seed the history for a brand-new state; a restored state
		# already contains its messages.
		if len(self.state.history.messages) == 0:
			self._init_messages()

	def _init_messages(self) -> None:
		"""Initialize the message history with system message, context, task, and other initial messages"""
		self._add_message_with_tokens(self.system_prompt)

		if self.settings.message_context:
			# ': ' separator added — the original concatenation produced
			# 'Context for the task<context>' with no delimiter.
			context_message = HumanMessage(content='Context for the task: ' + self.settings.message_context)
			self._add_message_with_tokens(context_message)

		task_message = HumanMessage(
			content=f'Your ultimate task is: """{self.task}""". If you achieved your ultimate task, stop everything and use the done action in the next step to complete the task. If not, continue as usual.'
		)
		self._add_message_with_tokens(task_message)

		if self.settings.sensitive_data:
			# Tell the model which placeholders exist without leaking values.
			# Typo 'sensitve' fixed; '\n' added so the two sentences don't fuse.
			info = f'Here are placeholders for sensitive data: {list(self.settings.sensitive_data.keys())}'
			info += '\nTo use them, write <secret>the placeholder name</secret>'
			info_message = HumanMessage(content=info)
			self._add_message_with_tokens(info_message)

		placeholder_message = HumanMessage(content='Example output:')
		self._add_message_with_tokens(placeholder_message)

		# Hand-written example of the expected AgentOutput tool call so the
		# model sees the required schema before its first real step.
		tool_calls = [
			{
				'name': 'AgentOutput',
				'args': {
					'current_state': {
						'evaluation_previous_goal': 'Success - I opened the first page',
						'memory': 'Starting with the new task. I have completed 1/10 steps',
						'next_goal': 'Click on company a',
					},
					'action': [{'click_element': {'index': 0}}],
				},
				'id': str(self.state.tool_id),
				'type': 'tool_call',
			}
		]

		example_tool_call = AIMessage(
			content='',
			tool_calls=tool_calls,
		)
		self._add_message_with_tokens(example_tool_call)
		# Matching tool response keeps the tool-call/tool-result pairing valid.
		self.add_tool_message(content='Browser started')

		placeholder_message = HumanMessage(content='[Your task history memory starts here]')
		self._add_message_with_tokens(placeholder_message)

		if self.settings.available_file_paths:
			filepaths_msg = HumanMessage(content=f'Here are file paths you can use: {self.settings.available_file_paths}')
			self._add_message_with_tokens(filepaths_msg)

	def add_new_task(self, new_task: str) -> None:
		"""Append a follow-up task message and make it the current task."""
		content = f'Your new ultimate task is: """{new_task}""". Take the previous context into account and finish your new ultimate task. '
		msg = HumanMessage(content=content)
		self._add_message_with_tokens(msg)
		self.task = new_task

	@time_execution_sync('--add_state_message')
	def add_state_message(
		self,
		state: BrowserState,
		result: Optional[List[ActionResult]] = None,
		step_info: Optional[AgentStepInfo] = None,
		use_vision=True,
	) -> None:
		"""Add browser state as human message.

		Results flagged ``include_in_memory`` are stored as standalone messages
		(so they survive history trimming) and excluded from the state prompt.
		"""
		if result:
			for r in result:
				if r.include_in_memory:
					if r.extracted_content:
						msg = HumanMessage(content='Action result: ' + str(r.extracted_content))
						self._add_message_with_tokens(msg)
					if r.error:
						# Fix: operate on a local copy instead of mutating the
						# caller's ActionResult; keep only the last line of a
						# (possibly multi-line) traceback.
						error_text = r.error[:-1] if r.error.endswith('\n') else r.error
						last_line = error_text.split('\n')[-1]
						msg = HumanMessage(content='Action error: ' + last_line)
						self._add_message_with_tokens(msg)
					# Already persisted above; keep it out of the state prompt.
					result = None

		# The state message (with any remaining result) does not stay in memory.
		state_message = AgentMessagePrompt(
			state,
			result,
			include_attributes=self.settings.include_attributes,
			step_info=step_info,
		).get_user_message(use_vision)
		self._add_message_with_tokens(state_message)

	def add_model_output(self, model_output: AgentOutput) -> None:
		"""Add model output as AI message"""
		tool_calls = [
			{
				'name': 'AgentOutput',
				'args': model_output.model_dump(mode='json', exclude_unset=True),
				'id': str(self.state.tool_id),
				'type': 'tool_call',
			}
		]

		msg = AIMessage(
			content='',
			tool_calls=tool_calls,
		)

		self._add_message_with_tokens(msg)

		# Empty tool response keeps the tool-call/tool-result pairing valid.
		self.add_tool_message(content='')

	def add_plan(self, plan: Optional[str], position: int | None = None) -> None:
		"""Insert the planner's output (if any) at ``position`` in the history."""
		if plan:
			msg = AIMessage(content=plan)
			self._add_message_with_tokens(msg, position)

	@time_execution_sync('--get_messages')
	def get_messages(self) -> List[BaseMessage]:
		"""Get current message list, potentially trimmed to max tokens"""
		msg = [m.message for m in self.state.history.messages]

		# Debug-only accounting of the per-message token estimates.
		total_input_tokens = 0
		logger.debug(f'Messages in history: {len(self.state.history.messages)}:')
		for m in self.state.history.messages:
			total_input_tokens += m.metadata.tokens
			logger.debug(f'{m.message.__class__.__name__} - Token count: {m.metadata.tokens}')
		logger.debug(f'Total input tokens: {total_input_tokens}')

		return msg

	def _add_message_with_tokens(self, message: BaseMessage, position: int | None = None) -> None:
		"""Add message with token count metadata
		position: None for last, -1 for second last, etc.
		"""
		# Redact secrets before the message is stored anywhere.
		if self.settings.sensitive_data:
			message = self._filter_sensitive_data(message)

		token_count = self._count_tokens(message)
		metadata = MessageMetadata(tokens=token_count)
		self.state.history.add_message(message, metadata, position)

	@time_execution_sync('--filter_sensitive_data')
	def _filter_sensitive_data(self, message: BaseMessage) -> BaseMessage:
		"""Filter out sensitive data from the message"""

		def replace_sensitive(value: str) -> str:
			if not self.settings.sensitive_data:
				return value
			for key, val in self.settings.sensitive_data.items():
				if not val:
					# Empty secret values would make replace() a no-op anyway.
					continue
				value = value.replace(val, f'<secret>{key}</secret>')
			return value

		if isinstance(message.content, str):
			message.content = replace_sensitive(message.content)
		elif isinstance(message.content, list):
			for i, item in enumerate(message.content):
				if isinstance(item, dict) and 'text' in item:
					item['text'] = replace_sensitive(item['text'])
					message.content[i] = item
		return message

	def _count_tokens(self, message: BaseMessage) -> int:
		"""Estimate tokens in a message (character heuristic + flat image cost)."""
		tokens = 0
		if isinstance(message.content, list):
			for item in message.content:
				if 'image_url' in item:
					# Flat per-image cost from settings.
					tokens += self.settings.image_tokens
				elif isinstance(item, dict) and 'text' in item:
					tokens += self._count_text_tokens(item['text'])
		else:
			msg = message.content
			if hasattr(message, 'tool_calls'):
				# Serialized tool calls consume tokens too.
				msg += str(message.tool_calls)
			tokens += self._count_text_tokens(msg)
		return tokens

	def _count_text_tokens(self, text: str) -> int:
		"""Count tokens in a text string"""
		# Rough character-based estimate; avoids a tokenizer dependency.
		tokens = len(text) // self.settings.estimated_characters_per_token
		return tokens

	def cut_messages(self):
		"""Trim the last (state) message until the history fits max_input_tokens."""
		diff = self.state.history.current_tokens - self.settings.max_input_tokens
		if diff <= 0:
			return None

		msg = self.state.history.messages[-1]

		# Step 1: if the content is a list, drop images (the most expensive
		# part) and collapse the remaining text parts into a plain string.
		if isinstance(msg.message.content, list):
			text = ''
			# Fix: iterate WITHOUT mutating the list — the original called
			# content.remove(item) during iteration, which skips the element
			# after each removed image and corrupts the token accounting.
			# Replacing content with `text` below drops the images anyway.
			for item in msg.message.content:
				if 'image_url' in item:
					diff -= self.settings.image_tokens
					msg.metadata.tokens -= self.settings.image_tokens
					self.state.history.current_tokens -= self.settings.image_tokens
					logger.debug(
						f'Removed image with {self.settings.image_tokens} tokens - total tokens now: {self.state.history.current_tokens}/{self.settings.max_input_tokens}'
					)
				elif 'text' in item and isinstance(item, dict):
					text += item['text']
			msg.message.content = text
			self.state.history.messages[-1] = msg

		if diff <= 0:
			return None

		# Step 2: still over budget — remove a proportional slice of the text.
		proportion_to_remove = diff / msg.metadata.tokens
		if proportion_to_remove > 0.99:
			raise ValueError(
				f'Max token limit reached - history is too long - reduce the system prompt or task. '
				f'proportion_to_remove: {proportion_to_remove}'
			)
		logger.debug(
			f'Removing {proportion_to_remove * 100:.2f}% of the last message ({proportion_to_remove * msg.metadata.tokens:.2f} / {msg.metadata.tokens:.2f} tokens)'
		)

		content = msg.message.content
		characters_to_remove = int(len(content) * proportion_to_remove)
		# Fix: guard against characters_to_remove == 0 — content[:-0] is ''
		# and would have wiped the entire message.
		if characters_to_remove > 0:
			content = content[:-characters_to_remove]

		# Remove the stale state message and re-add the trimmed text so the
		# token bookkeeping is recomputed by _add_message_with_tokens.
		self.state.history.remove_last_state_message()

		msg = HumanMessage(content=content)
		self._add_message_with_tokens(msg)

		last_msg = self.state.history.messages[-1]

		logger.debug(
			f'Added message with {last_msg.metadata.tokens} tokens - total tokens now: {self.state.history.current_tokens}/{self.settings.max_input_tokens} - total messages: {len(self.state.history.messages)}'
		)

	def _remove_last_state_message(self) -> None:
		"""Remove last state message from history"""
		self.state.history.remove_last_state_message()

	def add_tool_message(self, content: str) -> None:
		"""Add tool message to history"""
		msg = ToolMessage(content=content, tool_call_id=str(self.state.tool_id))
		# Monotonic id keeps each example/real tool call uniquely paired.
		self.state.tool_id += 1
		self._add_message_with_tokens(msg)
|
|
|