# File size: 5,632 Bytes (revision e3afff6)
import json
import logging
from typing import Optional, Type
from PIL import Image, ImageDraw, ImageFont
import os
import base64
import io
from browser_use.agent.prompts import SystemPrompt
from browser_use.agent.service import Agent
from browser_use.agent.views import (
ActionResult,
AgentHistoryList,
AgentOutput,
AgentHistory,
)
from browser_use.browser.browser import Browser
from browser_use.browser.context import BrowserContext
from browser_use.browser.views import BrowserStateHistory
from browser_use.controller.service import Controller
from browser_use.telemetry.views import (
AgentEndTelemetryEvent,
AgentRunTelemetryEvent,
AgentStepErrorTelemetryEvent,
)
from browser_use.utils import time_execution_async
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import (
BaseMessage,
)
from src.utils.agent_state import AgentState
from .custom_massage_manager import CustomMassageManager
from .custom_views import CustomAgentOutput, CustomAgentStepInfo
# Module-level logger, named after this module, used by CustomAgent below.
logger = logging.getLogger(__name__)
class CustomAgent(Agent):
    """``Agent`` subclass that uses ``CustomMassageManager`` for prompting.

    In addition to the arguments accepted by ``Agent``, it stores free-form
    extra task information (``add_infos``) and an optional ``AgentState``.
    """

    # Attributes used when the caller does not supply ``include_attributes``.
    # Kept as an immutable tuple so the default cannot be mutated by accident.
    _DEFAULT_INCLUDE_ATTRIBUTES = (
        "title",
        "type",
        "name",
        "role",
        "tabindex",
        "aria-label",
        "placeholder",
        "value",
        "alt",
        "aria-expanded",
    )

    # Model names for which function calling is switched off.
    _NO_FUNCTION_CALLING_MODELS = ("deepseek-reasoner",)

    def __init__(
        self,
        task: str,
        llm: BaseChatModel,
        add_infos: str = "",
        browser: Browser | None = None,
        browser_context: BrowserContext | None = None,
        controller: Controller | None = None,
        use_vision: bool = True,
        save_conversation_path: Optional[str] = None,
        max_failures: int = 5,
        retry_delay: int = 10,
        system_prompt_class: Type[SystemPrompt] = SystemPrompt,
        max_input_tokens: int = 128000,
        validate_output: bool = False,
        include_attributes: list[str] | None = None,
        max_error_length: int = 400,
        max_actions_per_step: int = 10,
        tool_call_in_content: bool = True,
        agent_state: AgentState | None = None,
    ):
        """Initialise the base agent, then install the custom message manager.

        Args:
            task: Task description, forwarded to ``Agent``.
            llm: Chat model, forwarded to ``Agent``.
            add_infos: Extra information stored on ``self.add_infos``.
            browser: Forwarded to ``Agent``.
            browser_context: Forwarded to ``Agent``.
            controller: Forwarded to ``Agent``. When ``None`` a new
                ``Controller()`` is created for this instance.
            use_vision: Forwarded to ``Agent``.
            save_conversation_path: Forwarded to ``Agent``.
            max_failures: Forwarded to ``Agent``.
            retry_delay: Forwarded to ``Agent``.
            system_prompt_class: Forwarded to ``Agent``.
            max_input_tokens: Forwarded to ``Agent``.
            validate_output: Forwarded to ``Agent``.
            include_attributes: Forwarded to ``Agent``. When ``None`` a fresh
                list built from ``_DEFAULT_INCLUDE_ATTRIBUTES`` is used.
            max_error_length: Forwarded to ``Agent``.
            max_actions_per_step: Forwarded to ``Agent``.
            tool_call_in_content: Forwarded to ``Agent`` and to the message
                manager.
            agent_state: Stored on ``self.agent_state``.
        """
        # Build per-instance values instead of sharing one default object
        # across every CustomAgent created without these arguments.
        if controller is None:
            controller = Controller()
        if include_attributes is None:
            include_attributes = list(self._DEFAULT_INCLUDE_ATTRIBUTES)

        super().__init__(
            task=task,
            llm=llm,
            browser=browser,
            browser_context=browser_context,
            controller=controller,
            use_vision=use_vision,
            save_conversation_path=save_conversation_path,
            max_failures=max_failures,
            retry_delay=retry_delay,
            system_prompt_class=system_prompt_class,
            max_input_tokens=max_input_tokens,
            validate_output=validate_output,
            include_attributes=include_attributes,
            max_error_length=max_error_length,
            max_actions_per_step=max_actions_per_step,
            tool_call_in_content=tool_call_in_content,
        )

        if (
            hasattr(self.llm, 'model_name')
            and self.llm.model_name in self._NO_FUNCTION_CALLING_MODELS
        ):
            self.use_function_calling = False
            # This override happens after the base initialiser ran, so it is
            # the value handed to the message manager created below.
            self.max_input_tokens = 64000
        else:
            self.use_function_calling = True

        self.add_infos = add_infos
        self.agent_state = agent_state

        self.message_manager = CustomMassageManager(
            llm=self.llm,
            task=self.task,
            action_descriptions=self.controller.registry.get_prompt_description(),
            system_prompt_class=self.system_prompt_class,
            max_input_tokens=self.max_input_tokens,
            include_attributes=self.include_attributes,
            max_error_length=self.max_error_length,
            max_actions_per_step=self.max_actions_per_step,
            tool_call_in_content=tool_call_in_content,
            use_function_calling=self.use_function_calling,
        )

    async def get_next_action(self, input_messages: list[BaseMessage]) -> AgentOutput:
        """Ask the LLM for the next action(s) and return the parsed output.

        The action list is truncated to ``self.max_actions_per_step`` and
        ``self.n_steps`` is incremented on success.

        Args:
            input_messages: Messages sent to the structured-output model.

        Returns:
            The parsed ``AgentOutput``.

        Raises:
            ValueError: If the response carries no parsed output.
            Exception: Any other error is logged and re-raised unchanged.
        """
        try:
            structured_llm = self.llm.with_structured_output(self.AgentOutput, include_raw=True)
            response = await structured_llm.ainvoke(input_messages)
            parsed = response['parsed']
            if parsed is None:
                # With include_raw=True a parsing failure is reported in the
                # response instead of being raised; fail with a clear message
                # rather than an AttributeError on ``None`` below.
                raise ValueError(
                    f"Could not parse model response: {response.get('parsing_error')}"
                )
            parsed.action = parsed.action[: self.max_actions_per_step]
            self._log_response(parsed)
            self.n_steps += 1
            return parsed
        except Exception as e:
            logger.error(f"Error in get_next_action: {e}")
            raise

    async def step(self, step_info: Optional[CustomAgentStepInfo] = None) -> None:
        """Run one step: read browser state, query the LLM, execute actions.

        Errors are logged and converted into ``self._last_result`` through
        ``self._handle_step_error``. A history item is recorded whenever the
        browser state was obtained.

        Args:
            step_info: Optional step information passed to the message
                manager and to ``update_step_info``.
        """
        logger.info(f"Step {self.n_steps}")
        state = None
        model_output = None

        try:
            state = await self.browser_context.get_state(use_vision=self.use_vision)
            self.message_manager.add_state_message(state, self._last_result, step_info)
            input_messages = self.message_manager.get_messages()
            model_output = await self.get_next_action(input_messages)
            # NOTE(review): update_step_info is not defined in this class body
            # as shown; presumably inherited or defined elsewhere - confirm.
            self.update_step_info(model_output, step_info)
            self._last_result = await self.controller.multi_act(
                model_output.action, self.browser_context
            )

            if len(self._last_result) > 0 and self._last_result[-1].is_done:
                logger.info(f"Task completed with result: {self._last_result[-1].extracted_content}")

            # Reaching this point means the step ran without raising.
            self.consecutive_failures = 0
        except Exception as e:
            logger.error(f"Error in step: {e}")
            self._last_result = self._handle_step_error(e)
        finally:
            # Only record history when a browser state was obtained.
            if state:
                self._make_history_item(model_output, state, self._last_result)
|