import base64
import io
import json
import logging
import os
from typing import Any, Optional, Type

from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import (
    BaseMessage,
)
from PIL import Image, ImageDraw, ImageFont

from browser_use.agent.prompts import SystemPrompt
from browser_use.agent.service import Agent
from browser_use.agent.views import (
    ActionResult,
    AgentHistory,
    AgentHistoryList,
    AgentOutput,
)
from browser_use.browser.browser import Browser
from browser_use.browser.context import BrowserContext
from browser_use.browser.views import BrowserStateHistory
from browser_use.controller.service import Controller
from browser_use.telemetry.views import (
    AgentEndTelemetryEvent,
    AgentRunTelemetryEvent,
    AgentStepErrorTelemetryEvent,
)
from browser_use.utils import time_execution_async
from src.utils.agent_state import AgentState

from .custom_massage_manager import CustomMassageManager
from .custom_views import CustomAgentOutput, CustomAgentStepInfo
| logger = logging.getLogger(__name__) | |
| class CustomAgent(Agent): | |
| def __init__( | |
| self, | |
| task: str, | |
| llm: BaseChatModel, | |
| add_infos: str = "", | |
| browser: Browser | None = None, | |
| browser_context: BrowserContext | None = None, | |
| controller: Controller = Controller(), | |
| use_vision: bool = True, | |
| save_conversation_path: Optional[str] = None, | |
| max_failures: int = 5, | |
| retry_delay: int = 10, | |
| system_prompt_class: Type[SystemPrompt] = SystemPrompt, | |
| max_input_tokens: int = 128000, | |
| validate_output: bool = False, | |
| include_attributes: list[str] = [ | |
| "title", | |
| "type", | |
| "name", | |
| "role", | |
| "tabindex", | |
| "aria-label", | |
| "placeholder", | |
| "value", | |
| "alt", | |
| "aria-expanded", | |
| ], | |
| max_error_length: int = 400, | |
| max_actions_per_step: int = 10, | |
| tool_call_in_content: bool = True, | |
| agent_state: AgentState = None, | |
| ): | |
| super().__init__( | |
| task=task, | |
| llm=llm, | |
| browser=browser, | |
| browser_context=browser_context, | |
| controller=controller, | |
| use_vision=use_vision, | |
| save_conversation_path=save_conversation_path, | |
| max_failures=max_failures, | |
| retry_delay=retry_delay, | |
| system_prompt_class=system_prompt_class, | |
| max_input_tokens=max_input_tokens, | |
| validate_output=validate_output, | |
| include_attributes=include_attributes, | |
| max_error_length=max_error_length, | |
| max_actions_per_step=max_actions_per_step, | |
| tool_call_in_content=tool_call_in_content, | |
| ) | |
| if hasattr(self.llm, 'model_name') and self.llm.model_name in ["deepseek-reasoner"]: | |
| self.use_function_calling = False | |
| self.max_input_tokens = 64000 | |
| else: | |
| self.use_function_calling = True | |
| self.add_infos = add_infos | |
| self.agent_state = agent_state | |
| self.message_manager = CustomMassageManager( | |
| llm=self.llm, | |
| task=self.task, | |
| action_descriptions=self.controller.registry.get_prompt_description(), | |
| system_prompt_class=self.system_prompt_class, | |
| max_input_tokens=self.max_input_tokens, | |
| include_attributes=self.include_attributes, | |
| max_error_length=self.max_error_length, | |
| max_actions_per_step=self.max_actions_per_step, | |
| tool_call_in_content=tool_call_in_content, | |
| use_function_calling=self.use_function_calling | |
| ) | |
| async def get_next_action(self, input_messages: list[BaseMessage]) -> AgentOutput: | |
| try: | |
| structured_llm = self.llm.with_structured_output(self.AgentOutput, include_raw=True) | |
| response: dict[str, any] = await structured_llm.ainvoke(input_messages) | |
| parsed: AgentOutput = response['parsed'] | |
| parsed.action = parsed.action[: self.max_actions_per_step] | |
| self._log_response(parsed) | |
| self.n_steps += 1 | |
| return parsed | |
| except Exception as e: | |
| logger.error(f"Error in get_next_action: {e}") | |
| raise | |
| async def step(self, step_info: Optional[CustomAgentStepInfo] = None) -> None: | |
| logger.info(f"Step {self.n_steps}") | |
| state = None | |
| model_output = None | |
| result: list[ActionResult] = [] | |
| try: | |
| state = await self.browser_context.get_state(use_vision=self.use_vision) | |
| self.message_manager.add_state_message(state, self._last_result, step_info) | |
| input_messages = self.message_manager.get_messages() | |
| model_output = await self.get_next_action(input_messages) | |
| self.update_step_info(model_output, step_info) | |
| self._last_result = await self.controller.multi_act(model_output.action, self.browser_context) | |
| if len(self._last_result) > 0 and self._last_result[-1].is_done: | |
| logger.info(f"Task completed with result: {self._last_result[-1].extracted_content}") | |
| self.consecutive_failures = 0 | |
| except Exception as e: | |
| logger.error(f"Error in step: {e}") | |
| self._last_result = self._handle_step_error(e) | |
| finally: | |
| if state: | |
| self._make_history_item(model_output, state, self._last_result) | |