# Webui / src/agent/custom_agent.py
# Last change by Josedcape: "Update src/agent/custom_agent.py" (commit e3afff6, verified)
import json
import logging
from typing import Optional, Type
from PIL import Image, ImageDraw, ImageFont
import os
import base64
import io
from browser_use.agent.prompts import SystemPrompt
from browser_use.agent.service import Agent
from browser_use.agent.views import (
ActionResult,
AgentHistoryList,
AgentOutput,
AgentHistory,
)
from browser_use.browser.browser import Browser
from browser_use.browser.context import BrowserContext
from browser_use.browser.views import BrowserStateHistory
from browser_use.controller.service import Controller
from browser_use.telemetry.views import (
AgentEndTelemetryEvent,
AgentRunTelemetryEvent,
AgentStepErrorTelemetryEvent,
)
from browser_use.utils import time_execution_async
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import (
BaseMessage,
)
from src.utils.agent_state import AgentState
from .custom_massage_manager import CustomMassageManager
from .custom_views import CustomAgentOutput, CustomAgentStepInfo
# Module-level logger, named after this module's import path (``__name__``).
logger = logging.getLogger(name=__name__)
class CustomAgent(Agent):
def __init__(
    self,
    task: str,
    llm: BaseChatModel,
    add_infos: str = "",
    browser: Browser | None = None,
    browser_context: BrowserContext | None = None,
    controller: Controller | None = None,
    use_vision: bool = True,
    save_conversation_path: Optional[str] = None,
    max_failures: int = 5,
    retry_delay: int = 10,
    system_prompt_class: Type[SystemPrompt] = SystemPrompt,
    max_input_tokens: int = 128000,
    validate_output: bool = False,
    include_attributes: list[str] | None = None,
    max_error_length: int = 400,
    max_actions_per_step: int = 10,
    tool_call_in_content: bool = True,
    agent_state: AgentState | None = None,
):
    """Create the agent and swap in the custom message manager.

    All browser-use ``Agent`` options are forwarded to the base class. The
    agent then decides whether the LLM is driven through function calling,
    and replaces the base message manager with ``CustomMassageManager``.

    Args:
        task: Task description, passed to the base agent and message manager.
        llm: LangChain chat model used to choose actions.
        add_infos: Extra information stored as ``self.add_infos``. It is
            consumed outside this method.
        browser: Optional existing ``Browser`` forwarded to the base agent.
        browser_context: Optional existing ``BrowserContext`` forwarded to the
            base agent.
        controller: Action controller. ``None`` (the default) creates a fresh
            ``Controller`` for this agent.
        use_vision: Forwarded to the base agent.
        save_conversation_path: Forwarded to the base agent.
        max_failures: Forwarded to the base agent.
        retry_delay: Forwarded to the base agent.
        system_prompt_class: System prompt class for the message manager.
        max_input_tokens: Token budget for the message manager. For
            ``deepseek-reasoner`` it is capped at 64000.
        validate_output: Forwarded to the base agent.
        include_attributes: DOM attributes to include in the state prompt.
            ``None`` (the default) selects the built-in list below.
        max_error_length: Forwarded to the base agent and message manager.
        max_actions_per_step: Forwarded to the base agent and message manager.
        tool_call_in_content: Forwarded to the base agent and message manager.
        agent_state: Stored as ``self.agent_state`` for use by other methods.
            Presumably used for stop/pause signalling; confirm with callers.
    """
    if controller is None:
        # A ``Controller()`` default would be built once at definition time
        # and shared by every agent, so actions registered on one agent would
        # leak into all others.
        controller = Controller()
    if include_attributes is None:
        # Build a fresh list per instance rather than sharing a mutable default.
        include_attributes = [
            "title",
            "type",
            "name",
            "role",
            "tabindex",
            "aria-label",
            "placeholder",
            "value",
            "alt",
            "aria-expanded",
        ]
    super().__init__(
        task=task,
        llm=llm,
        browser=browser,
        browser_context=browser_context,
        controller=controller,
        use_vision=use_vision,
        save_conversation_path=save_conversation_path,
        max_failures=max_failures,
        retry_delay=retry_delay,
        system_prompt_class=system_prompt_class,
        max_input_tokens=max_input_tokens,
        validate_output=validate_output,
        include_attributes=include_attributes,
        max_error_length=max_error_length,
        max_actions_per_step=max_actions_per_step,
        tool_call_in_content=tool_call_in_content,
    )
    model_name = getattr(self.llm, "model_name", None)
    if model_name in ("deepseek-reasoner",):
        # Presumably this model does not support tool/function calling, so the
        # agent falls back to plain-text JSON replies. The 64000 cap presumably
        # reflects its context limit. A smaller caller budget is kept as-is.
        self.use_function_calling = False
        self.max_input_tokens = min(max_input_tokens, 64000)
    else:
        self.use_function_calling = True
    self.add_infos = add_infos
    self.agent_state = agent_state
    # Replace the message manager created by the base class with the custom one,
    # built from the (possibly adjusted) settings above.
    self.message_manager = CustomMassageManager(
        llm=self.llm,
        task=self.task,
        action_descriptions=self.controller.registry.get_prompt_description(),
        system_prompt_class=self.system_prompt_class,
        max_input_tokens=self.max_input_tokens,
        include_attributes=self.include_attributes,
        max_error_length=self.max_error_length,
        max_actions_per_step=self.max_actions_per_step,
        tool_call_in_content=tool_call_in_content,
        use_function_calling=self.use_function_calling,
    )
async def get_next_action(self, input_messages: list[BaseMessage]) -> AgentOutput:
    """Ask the LLM for the next agent output (brain state + actions).

    If ``self.use_function_calling`` is true, the model is called through
    LangChain structured output for ``self.AgentOutput``. Otherwise the model
    is called directly, and the first ``{`` ... last ``}`` span of its text
    reply is parsed as JSON into ``self.AgentOutput``. This assumes the custom
    message manager (built with ``use_function_calling=False``) prompts for a
    JSON reply; verify.

    The action list is truncated to ``self.max_actions_per_step``. The
    response is then logged and ``self.n_steps`` is incremented.

    Args:
        input_messages: Conversation to send to the LLM.

    Returns:
        The parsed ``self.AgentOutput`` instance.

    Raises:
        ValueError: If the LLM reply cannot be parsed into ``self.AgentOutput``.
        Exception: Any LLM/parsing error is logged and re-raised.
    """
    try:
        if getattr(self, "use_function_calling", True):
            structured_llm = self.llm.with_structured_output(self.AgentOutput, include_raw=True)
            response: dict[str, object] = await structured_llm.ainvoke(input_messages)
            parsed = response.get("parsed")
            if parsed is None:
                # With include_raw=True a parse failure yields parsed=None
                # instead of raising; surface it with the underlying reason.
                raise ValueError(
                    f"Could not parse LLM response into {self.AgentOutput.__name__}: "
                    f"{response.get('parsing_error')}"
                )
        else:
            # Function calling is disabled for this model: read the JSON
            # object out of the plain-text reply instead.
            ai_message = await self.llm.ainvoke(input_messages)
            content = ai_message.content
            if isinstance(content, list):
                # Multi-part content: keep only the textual parts.
                pieces = []
                for part in content:
                    if isinstance(part, str):
                        pieces.append(part)
                    elif isinstance(part, dict):
                        pieces.append(str(part.get("text", "")))
                content = "".join(pieces)
            # Slicing from the first "{" to the last "}" drops ```json fences
            # and any prose the model wrapped around the object.
            start, end = content.find("{"), content.rfind("}")
            if start == -1 or end < start:
                raise ValueError("LLM response does not contain a JSON object")
            parsed = self.AgentOutput(**json.loads(content[start : end + 1]))
        parsed.action = parsed.action[: self.max_actions_per_step]
        self._log_response(parsed)
        self.n_steps += 1
        return parsed
    except Exception as e:
        logger.error(f"Error in get_next_action: {e}")
        raise
async def step(self, step_info: Optional[CustomAgentStepInfo] = None) -> None:
    """Run one observe -> decide -> act cycle.

    The cycle:
    1. Reads the browser state.
    2. Adds it to the message manager together with the previous result.
    3. Asks the LLM for the next actions.
    4. Updates ``step_info``.
    5. Executes the actions through the controller.

    Results are stored in ``self._last_result``. On any error, the result
    returned by ``self._handle_step_error`` is stored instead. A history item
    is recorded whenever a browser state was obtained, even if a later part
    of the step failed.

    Args:
        step_info: Optional per-step bookkeeping, passed to the message
            manager and ``self.update_step_info``.
    """
    logger.info(f"Step {self.n_steps}")
    state = None
    model_output = None
    try:
        state = await self.browser_context.get_state(use_vision=self.use_vision)
        self.message_manager.add_state_message(state, self._last_result, step_info)
        input_messages = self.message_manager.get_messages()
        model_output = await self.get_next_action(input_messages)
        self.update_step_info(model_output, step_info)
        # NOTE(review): nothing here appends model_output to the message
        # manager or writes the conversation to save_conversation_path. Confirm
        # that is intentional; otherwise the LLM may not see its own previous
        # actions on later steps.
        self._last_result = await self.controller.multi_act(model_output.action, self.browser_context)
        if self._last_result and self._last_result[-1].is_done:
            logger.info(f"Task completed with result: {self._last_result[-1].extracted_content}")
        # A fully successful step resets the failure streak.
        self.consecutive_failures = 0
    except Exception as e:
        logger.error(f"Error in step: {e}")
        self._last_result = self._handle_step_error(e)
    finally:
        # Record history even for failed steps, as long as a state was captured.
        # model_output may still be None here if the LLM call failed.
        if state:
            self._make_history_item(model_output, state, self._last_result)