"""Custom browser-use agent.

Extends ``browser_use.agent.service.Agent`` with:
- ``add_infos``: free-form hints forwarded to the prompt,
- ``agent_state``: an externally controlled stop/pause state,
- a custom message manager (``CustomMassageManager``),
- a fallback mode for models without function calling (deepseek-reasoner).
"""

import base64
import io
import json
import logging
import os
from typing import Any, Optional, Type

from browser_use.agent.prompts import SystemPrompt
from browser_use.agent.service import Agent
from browser_use.agent.views import (
    ActionResult,
    AgentHistory,
    AgentHistoryList,
    AgentOutput,
)
from browser_use.browser.browser import Browser
from browser_use.browser.context import BrowserContext
from browser_use.browser.views import BrowserStateHistory
from browser_use.controller.service import Controller
from browser_use.telemetry.views import (
    AgentEndTelemetryEvent,
    AgentRunTelemetryEvent,
    AgentStepErrorTelemetryEvent,
)
from browser_use.utils import time_execution_async
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import BaseMessage
from PIL import Image, ImageDraw, ImageFont

from src.utils.agent_state import AgentState

from .custom_massage_manager import CustomMassageManager
from .custom_views import CustomAgentOutput, CustomAgentStepInfo

logger = logging.getLogger(__name__)

# Default DOM attributes forwarded to the LLM when describing page elements.
# Kept at module level so it is never used as a shared mutable default argument.
DEFAULT_INCLUDE_ATTRIBUTES = [
    "title",
    "type",
    "name",
    "role",
    "tabindex",
    "aria-label",
    "placeholder",
    "value",
    "alt",
    "aria-expanded",
]


class CustomAgent(Agent):
    """Agent subclass with extra task context and a custom message manager.

    Behaviorally identical to the base ``Agent`` except for the custom
    message manager and the non-function-calling fallback for models that
    do not support tool calls (currently ``deepseek-reasoner``).
    """

    def __init__(
        self,
        task: str,
        llm: BaseChatModel,
        add_infos: str = "",
        browser: Browser | None = None,
        browser_context: BrowserContext | None = None,
        controller: Optional[Controller] = None,
        use_vision: bool = True,
        save_conversation_path: Optional[str] = None,
        max_failures: int = 5,
        retry_delay: int = 10,
        system_prompt_class: Type[SystemPrompt] = SystemPrompt,
        max_input_tokens: int = 128000,
        validate_output: bool = False,
        include_attributes: Optional[list[str]] = None,
        max_error_length: int = 400,
        max_actions_per_step: int = 10,
        tool_call_in_content: bool = True,
        agent_state: Optional[AgentState] = None,
    ):
        """Build the agent.

        Args:
            task: Natural-language task description.
            llm: Chat model used to plan actions.
            add_infos: Extra hints appended to the task context.
            controller: Action controller; a fresh ``Controller()`` is created
                per instance when omitted (avoids a shared mutable default).
            include_attributes: DOM attributes to expose to the LLM; defaults
                to a copy of ``DEFAULT_INCLUDE_ATTRIBUTES`` when omitted.
            agent_state: Optional externally controlled stop/pause state.
            (Remaining parameters are forwarded unchanged to ``Agent``.)
        """
        # Build fresh defaults per call instead of sharing mutable defaults
        # across every instance of the class.
        if controller is None:
            controller = Controller()
        if include_attributes is None:
            include_attributes = list(DEFAULT_INCLUDE_ATTRIBUTES)
        super().__init__(
            task=task,
            llm=llm,
            browser=browser,
            browser_context=browser_context,
            controller=controller,
            use_vision=use_vision,
            save_conversation_path=save_conversation_path,
            max_failures=max_failures,
            retry_delay=retry_delay,
            system_prompt_class=system_prompt_class,
            max_input_tokens=max_input_tokens,
            validate_output=validate_output,
            include_attributes=include_attributes,
            max_error_length=max_error_length,
            max_actions_per_step=max_actions_per_step,
            tool_call_in_content=tool_call_in_content,
        )
        # deepseek-reasoner does not support function calling and has a
        # smaller usable context, so fall back to plain-text parsing.
        if hasattr(self.llm, "model_name") and self.llm.model_name in ["deepseek-reasoner"]:
            self.use_function_calling = False
            self.max_input_tokens = 64000
        else:
            self.use_function_calling = True
        self.add_infos = add_infos
        self.agent_state = agent_state
        # Replace the base class's message manager with the custom one so the
        # extra context and non-function-calling mode are honored.
        self.message_manager = CustomMassageManager(
            llm=self.llm,
            task=self.task,
            action_descriptions=self.controller.registry.get_prompt_description(),
            system_prompt_class=self.system_prompt_class,
            max_input_tokens=self.max_input_tokens,
            include_attributes=self.include_attributes,
            max_error_length=self.max_error_length,
            max_actions_per_step=self.max_actions_per_step,
            tool_call_in_content=tool_call_in_content,
            use_function_calling=self.use_function_calling,
        )

    async def get_next_action(self, input_messages: list[BaseMessage]) -> AgentOutput:
        """Ask the LLM for the next action(s), capped at ``max_actions_per_step``.

        Args:
            input_messages: Full conversation history for the structured call.

        Returns:
            The parsed ``AgentOutput`` with at most ``max_actions_per_step``
            actions.

        Raises:
            Exception: re-raised after logging when the structured LLM call
                or parsing fails.
        """
        try:
            structured_llm = self.llm.with_structured_output(self.AgentOutput, include_raw=True)
            response: dict[str, Any] = await structured_llm.ainvoke(input_messages)
            parsed: AgentOutput = response['parsed']
            # Never execute more actions than the configured per-step budget.
            parsed.action = parsed.action[: self.max_actions_per_step]
            self._log_response(parsed)
            self.n_steps += 1
            return parsed
        except Exception as e:
            logger.error(f"Error in get_next_action: {e}")
            raise

    async def step(self, step_info: Optional[CustomAgentStepInfo] = None) -> None:
        """Execute one agent step: observe browser state, query the LLM, run actions.

        On error the failure is converted into ``ActionResult``s via
        ``_handle_step_error``. A history item is recorded in ``finally``
        whenever a browser state was captured, success or not.
        """
        logger.info(f"Step {self.n_steps}")
        state = None
        model_output = None
        try:
            state = await self.browser_context.get_state(use_vision=self.use_vision)
            self.message_manager.add_state_message(state, self._last_result, step_info)
            input_messages = self.message_manager.get_messages()
            model_output = await self.get_next_action(input_messages)
            self.update_step_info(model_output, step_info)
            self._last_result = await self.controller.multi_act(
                model_output.action, self.browser_context
            )
            if self._last_result and self._last_result[-1].is_done:
                logger.info(f"Task completed with result: {self._last_result[-1].extracted_content}")
            self.consecutive_failures = 0
        except Exception as e:
            logger.error(f"Error in step: {e}")
            self._last_result = self._handle_step_error(e)
        finally:
            # Record history even on failure, as long as a state was captured.
            if state:
                self._make_history_item(model_output, state, self._last_result)