Josedcape committed on
Commit
e3afff6
·
verified ·
1 Parent(s): 1620ca9

Update src/agent/custom_agent.py

Browse files
Files changed (1) hide show
  1. src/agent/custom_agent.py +147 -519
src/agent/custom_agent.py CHANGED
@@ -1,519 +1,147 @@
1
- import json
2
- import logging
3
- import pdb
4
- import traceback
5
- from typing import Optional, Type
6
- from PIL import Image, ImageDraw, ImageFont
7
- import os
8
- import base64
9
- import io
10
-
11
- from browser_use.agent.prompts import SystemPrompt
12
- from browser_use.agent.service import Agent
13
- from browser_use.agent.views import (
14
- ActionResult,
15
- AgentHistoryList,
16
- AgentOutput,
17
- AgentHistory,
18
- )
19
- from browser_use.browser.browser import Browser
20
- from browser_use.browser.context import BrowserContext
21
- from browser_use.browser.views import BrowserStateHistory
22
- from browser_use.controller.service import Controller
23
- from browser_use.telemetry.views import (
24
- AgentEndTelemetryEvent,
25
- AgentRunTelemetryEvent,
26
- AgentStepErrorTelemetryEvent,
27
- )
28
- from browser_use.utils import time_execution_async
29
- from langchain_core.language_models.chat_models import BaseChatModel
30
- from langchain_core.messages import (
31
- BaseMessage,
32
- )
33
- from src.utils.agent_state import AgentState
34
-
35
- from .custom_massage_manager import CustomMassageManager
36
- from .custom_views import CustomAgentOutput, CustomAgentStepInfo
37
-
38
- logger = logging.getLogger(__name__)
39
-
40
-
41
- class CustomAgent(Agent):
42
- def __init__(
43
- self,
44
- task: str,
45
- llm: BaseChatModel,
46
- add_infos: str = "",
47
- browser: Browser | None = None,
48
- browser_context: BrowserContext | None = None,
49
- controller: Controller = Controller(),
50
- use_vision: bool = True,
51
- save_conversation_path: Optional[str] = None,
52
- max_failures: int = 5,
53
- retry_delay: int = 10,
54
- system_prompt_class: Type[SystemPrompt] = SystemPrompt,
55
- max_input_tokens: int = 128000,
56
- validate_output: bool = False,
57
- include_attributes: list[str] = [
58
- "title",
59
- "type",
60
- "name",
61
- "role",
62
- "tabindex",
63
- "aria-label",
64
- "placeholder",
65
- "value",
66
- "alt",
67
- "aria-expanded",
68
- ],
69
- max_error_length: int = 400,
70
- max_actions_per_step: int = 10,
71
- tool_call_in_content: bool = True,
72
- agent_state: AgentState = None,
73
- ):
74
- super().__init__(
75
- task=task,
76
- llm=llm,
77
- browser=browser,
78
- browser_context=browser_context,
79
- controller=controller,
80
- use_vision=use_vision,
81
- save_conversation_path=save_conversation_path,
82
- max_failures=max_failures,
83
- retry_delay=retry_delay,
84
- system_prompt_class=system_prompt_class,
85
- max_input_tokens=max_input_tokens,
86
- validate_output=validate_output,
87
- include_attributes=include_attributes,
88
- max_error_length=max_error_length,
89
- max_actions_per_step=max_actions_per_step,
90
- tool_call_in_content=tool_call_in_content,
91
- )
92
- if hasattr(self.llm, 'model_name') and self.llm.model_name in ["deepseek-reasoner"]:
93
- # deepseek-reasoner does not support function calling
94
- self.use_function_calling = False
95
- # TODO: deepseek-reasoner only support 64000 context
96
- self.max_input_tokens = 64000
97
- else:
98
- self.use_function_calling = True
99
- self.add_infos = add_infos
100
- self.agent_state = agent_state
101
- self.message_manager = CustomMassageManager(
102
- llm=self.llm,
103
- task=self.task,
104
- action_descriptions=self.controller.registry.get_prompt_description(),
105
- system_prompt_class=self.system_prompt_class,
106
- max_input_tokens=self.max_input_tokens,
107
- include_attributes=self.include_attributes,
108
- max_error_length=self.max_error_length,
109
- max_actions_per_step=self.max_actions_per_step,
110
- tool_call_in_content=tool_call_in_content,
111
- use_function_calling=self.use_function_calling
112
- )
113
-
114
- def _setup_action_models(self) -> None:
115
- """Setup dynamic action models from controller's registry"""
116
- # Get the dynamic action model from controller's registry
117
- self.ActionModel = self.controller.registry.create_action_model()
118
- # Create output model with the dynamic actions
119
- self.AgentOutput = CustomAgentOutput.type_with_custom_actions(self.ActionModel)
120
-
121
- def _log_response(self, response: CustomAgentOutput) -> None:
122
- """Log the model's response"""
123
- if "Success" in response.current_state.prev_action_evaluation:
124
- emoji = "✅"
125
- elif "Failed" in response.current_state.prev_action_evaluation:
126
- emoji = "❌"
127
- else:
128
- emoji = "🤷"
129
-
130
- logger.info(f"{emoji} Eval: {response.current_state.prev_action_evaluation}")
131
- logger.info(f"🧠 New Memory: {response.current_state.important_contents}")
132
- logger.info(f"⏳ Task Progress: \n{response.current_state.task_progress}")
133
- logger.info(f"📋 Future Plans: \n{response.current_state.future_plans}")
134
- logger.info(f"🤔 Thought: {response.current_state.thought}")
135
- logger.info(f"🎯 Summary: {response.current_state.summary}")
136
- for i, action in enumerate(response.action):
137
- logger.info(
138
- f"🛠️ Action {i + 1}/{len(response.action)}: {action.model_dump_json(exclude_unset=True)}"
139
- )
140
-
141
- def update_step_info(
142
- self, model_output: CustomAgentOutput, step_info: CustomAgentStepInfo = None
143
- ):
144
- """
145
- update step info
146
- """
147
- if step_info is None:
148
- return
149
-
150
- step_info.step_number += 1
151
- important_contents = model_output.current_state.important_contents
152
- if (
153
- important_contents
154
- and "None" not in important_contents
155
- and important_contents not in step_info.memory
156
- ):
157
- step_info.memory += important_contents + "\n"
158
-
159
- task_progress = model_output.current_state.task_progress
160
- if task_progress and "None" not in task_progress:
161
- step_info.task_progress = task_progress
162
-
163
- future_plans = model_output.current_state.future_plans
164
- if future_plans and "None" not in future_plans:
165
- step_info.future_plans = future_plans
166
-
167
- @time_execution_async("--get_next_action")
168
- async def get_next_action(self, input_messages: list[BaseMessage]) -> AgentOutput:
169
- """Get next action from LLM based on current state"""
170
- if self.use_function_calling:
171
- try:
172
- structured_llm = self.llm.with_structured_output(self.AgentOutput, include_raw=True)
173
- response: dict[str, Any] = await structured_llm.ainvoke(input_messages) # type: ignore
174
-
175
- parsed: AgentOutput = response['parsed']
176
- # cut the number of actions to max_actions_per_step
177
- parsed.action = parsed.action[: self.max_actions_per_step]
178
- self._log_response(parsed)
179
- self.n_steps += 1
180
-
181
- return parsed
182
- except Exception as e:
183
- # If something goes wrong, try to invoke the LLM again without structured output,
184
- # and Manually parse the response. Temporarily solution for DeepSeek
185
- ret = self.llm.invoke(input_messages)
186
- if isinstance(ret.content, list):
187
- parsed_json = json.loads(ret.content[0].replace("```json", "").replace("```", ""))
188
- else:
189
- parsed_json = json.loads(ret.content.replace("```json", "").replace("```", ""))
190
- parsed: AgentOutput = self.AgentOutput(**parsed_json)
191
- if parsed is None:
192
- raise ValueError(f'Could not parse response.')
193
-
194
- # cut the number of actions to max_actions_per_step
195
- parsed.action = parsed.action[: self.max_actions_per_step]
196
- self._log_response(parsed)
197
- self.n_steps += 1
198
-
199
- return parsed
200
- else:
201
- ret = self.llm.invoke(input_messages)
202
- if not self.use_function_calling:
203
- self.message_manager._add_message_with_tokens(ret)
204
- logger.info(f"🤯 Start Deep Thinking: ")
205
- logger.info(ret.reasoning_content)
206
- logger.info(f"🤯 End Deep Thinking")
207
- if isinstance(ret.content, list):
208
- parsed_json = json.loads(ret.content[0].replace("```json", "").replace("```", ""))
209
- else:
210
- parsed_json = json.loads(ret.content.replace("```json", "").replace("```", ""))
211
- parsed: AgentOutput = self.AgentOutput(**parsed_json)
212
- if parsed is None:
213
- raise ValueError(f'Could not parse response.')
214
-
215
- # cut the number of actions to max_actions_per_step
216
- parsed.action = parsed.action[: self.max_actions_per_step]
217
- self._log_response(parsed)
218
- self.n_steps += 1
219
-
220
- return parsed
221
-
222
- @time_execution_async("--step")
223
- async def step(self, step_info: Optional[CustomAgentStepInfo] = None) -> None:
224
- """Execute one step of the task"""
225
- logger.info(f"\n📍 Step {self.n_steps}")
226
- state = None
227
- model_output = None
228
- result: list[ActionResult] = []
229
-
230
- try:
231
- state = await self.browser_context.get_state(use_vision=self.use_vision)
232
- self.message_manager.add_state_message(state, self._last_result, step_info)
233
- input_messages = self.message_manager.get_messages()
234
- model_output = await self.get_next_action(input_messages)
235
- self.update_step_info(model_output, step_info)
236
- logger.info(f"🧠 All Memory: \n{step_info.memory}")
237
- self._save_conversation(input_messages, model_output)
238
- if self.use_function_calling:
239
- self.message_manager._remove_last_state_message() # we dont want the whole state in the chat history
240
- self.message_manager.add_model_output(model_output)
241
-
242
- result: list[ActionResult] = await self.controller.multi_act(
243
- model_output.action, self.browser_context
244
- )
245
- if len(result) != len(model_output.action):
246
- # I think something changes, such information should let LLM know
247
- for ri in range(len(result), len(model_output.action)):
248
- result.append(ActionResult(extracted_content=None,
249
- include_in_memory=True,
250
- error=f"{model_output.action[ri].model_dump_json(exclude_unset=True)} is Failed to execute. \
251
- Something new appeared after action {model_output.action[len(result) - 1].model_dump_json(exclude_unset=True)}",
252
- is_done=False))
253
- self._last_result = result
254
-
255
- if len(result) > 0 and result[-1].is_done:
256
- logger.info(f"📄 Result: {result[-1].extracted_content}")
257
-
258
- self.consecutive_failures = 0
259
-
260
- except Exception as e:
261
- result = self._handle_step_error(e)
262
- self._last_result = result
263
-
264
- finally:
265
- if not result:
266
- return
267
- for r in result:
268
- if r.error:
269
- self.telemetry.capture(
270
- AgentStepErrorTelemetryEvent(
271
- agent_id=self.agent_id,
272
- error=r.error,
273
- )
274
- )
275
- if state:
276
- self._make_history_item(model_output, state, result)
277
- def create_history_gif(
278
- self,
279
- output_path: str = 'agent_history.gif',
280
- duration: int = 3000,
281
- show_goals: bool = True,
282
- show_task: bool = True,
283
- show_logo: bool = False,
284
- font_size: int = 40,
285
- title_font_size: int = 56,
286
- goal_font_size: int = 44,
287
- margin: int = 40,
288
- line_spacing: float = 1.5,
289
- ) -> None:
290
- """Create a GIF from the agent's history with overlaid task and goal text."""
291
- if not self.history.history:
292
- logger.warning('No history to create GIF from')
293
- return
294
-
295
- images = []
296
- # if history is empty or first screenshot is None, we can't create a gif
297
- if not self.history.history or not self.history.history[0].state.screenshot:
298
- logger.warning('No history or first screenshot to create GIF from')
299
- return
300
-
301
- # Try to load nicer fonts
302
- try:
303
- # Try different font options in order of preference
304
- font_options = ['Helvetica', 'Arial', 'DejaVuSans', 'Verdana']
305
- font_loaded = False
306
-
307
- for font_name in font_options:
308
- try:
309
- import platform
310
- if platform.system() == "Windows":
311
- # Need to specify the abs font path on Windows
312
- font_name = os.path.join(os.getenv("WIN_FONT_DIR", "C:\\Windows\\Fonts"), font_name + ".ttf")
313
- regular_font = ImageFont.truetype(font_name, font_size)
314
- title_font = ImageFont.truetype(font_name, title_font_size)
315
- goal_font = ImageFont.truetype(font_name, goal_font_size)
316
- font_loaded = True
317
- break
318
- except OSError:
319
- continue
320
-
321
- if not font_loaded:
322
- raise OSError('No preferred fonts found')
323
-
324
- except OSError:
325
- regular_font = ImageFont.load_default()
326
- title_font = ImageFont.load_default()
327
-
328
- goal_font = regular_font
329
-
330
- # Load logo if requested
331
- logo = None
332
- if show_logo:
333
- try:
334
- logo = Image.open('./static/browser-use.png')
335
- # Resize logo to be small (e.g., 40px height)
336
- logo_height = 150
337
- aspect_ratio = logo.width / logo.height
338
- logo_width = int(logo_height * aspect_ratio)
339
- logo = logo.resize((logo_width, logo_height), Image.Resampling.LANCZOS)
340
- except Exception as e:
341
- logger.warning(f'Could not load logo: {e}')
342
-
343
- # Create task frame if requested
344
- if show_task and self.task:
345
- task_frame = self._create_task_frame(
346
- self.task,
347
- self.history.history[0].state.screenshot,
348
- title_font,
349
- regular_font,
350
- logo,
351
- line_spacing,
352
- )
353
- images.append(task_frame)
354
-
355
- # Process each history item
356
- for i, item in enumerate(self.history.history, 1):
357
- if not item.state.screenshot:
358
- continue
359
-
360
- # Convert base64 screenshot to PIL Image
361
- img_data = base64.b64decode(item.state.screenshot)
362
- image = Image.open(io.BytesIO(img_data))
363
-
364
- if show_goals and item.model_output:
365
- image = self._add_overlay_to_image(
366
- image=image,
367
- step_number=i,
368
- goal_text=item.model_output.current_state.thought,
369
- regular_font=regular_font,
370
- title_font=title_font,
371
- margin=margin,
372
- logo=logo,
373
- )
374
-
375
- images.append(image)
376
-
377
- if images:
378
- # Save the GIF
379
- images[0].save(
380
- output_path,
381
- save_all=True,
382
- append_images=images[1:],
383
- duration=duration,
384
- loop=0,
385
- optimize=False,
386
- )
387
- logger.info(f'Created GIF at {output_path}')
388
- else:
389
- logger.warning('No images found in history to create GIF')
390
-
391
- async def run(self, max_steps: int = 100) -> AgentHistoryList:
392
- """Execute the task with maximum number of steps"""
393
- try:
394
- logger.info(f"🚀 Starting task: {self.task}")
395
-
396
- self.telemetry.capture(
397
- AgentRunTelemetryEvent(
398
- agent_id=self.agent_id,
399
- task=self.task,
400
- )
401
- )
402
-
403
- step_info = CustomAgentStepInfo(
404
- task=self.task,
405
- add_infos=self.add_infos,
406
- step_number=1,
407
- max_steps=max_steps,
408
- memory="",
409
- task_progress="",
410
- future_plans=""
411
- )
412
-
413
- for step in range(max_steps):
414
- # 1) Check if stop requested
415
- if self.agent_state and self.agent_state.is_stop_requested():
416
- logger.info("🛑 Stop requested by user")
417
- self._create_stop_history_item()
418
- break
419
-
420
- # 2) Store last valid state before step
421
- if self.browser_context and self.agent_state:
422
- state = await self.browser_context.get_state(use_vision=self.use_vision)
423
- self.agent_state.set_last_valid_state(state)
424
-
425
- if self._too_many_failures():
426
- break
427
-
428
- # 3) Do the step
429
- await self.step(step_info)
430
-
431
- if self.history.is_done():
432
- if (
433
- self.validate_output and step < max_steps - 1
434
- ): # if last step, we dont need to validate
435
- if not await self._validate_output():
436
- continue
437
-
438
- logger.info("✅ Task completed successfully")
439
- break
440
- else:
441
- logger.info("❌ Failed to complete task in maximum steps")
442
-
443
- return self.history
444
-
445
- finally:
446
- self.telemetry.capture(
447
- AgentEndTelemetryEvent(
448
- agent_id=self.agent_id,
449
- task=self.task,
450
- success=self.history.is_done(),
451
- steps=len(self.history.history),
452
- )
453
- )
454
- if not self.injected_browser_context:
455
- await self.browser_context.close()
456
-
457
- if not self.injected_browser and self.browser:
458
- await self.browser.close()
459
-
460
- if self.generate_gif:
461
- self.create_history_gif()
462
-
463
- def _create_stop_history_item(self):
464
- """Create a history item for when the agent is stopped."""
465
- try:
466
- # Attempt to retrieve the last valid state from agent_state
467
- state = None
468
- if self.agent_state:
469
- last_state = self.agent_state.get_last_valid_state()
470
- if last_state:
471
- # Convert to BrowserStateHistory
472
- state = BrowserStateHistory(
473
- url=getattr(last_state, 'url', ""),
474
- title=getattr(last_state, 'title', ""),
475
- tabs=getattr(last_state, 'tabs', []),
476
- interacted_element=[None],
477
- screenshot=getattr(last_state, 'screenshot', None)
478
- )
479
- else:
480
- state = self._create_empty_state()
481
- else:
482
- state = self._create_empty_state()
483
-
484
- # Create a final item in the agent history indicating done
485
- stop_history = AgentHistory(
486
- model_output=None,
487
- state=state,
488
- result=[ActionResult(extracted_content=None, error=None, is_done=True)]
489
- )
490
- self.history.history.append(stop_history)
491
-
492
- except Exception as e:
493
- logger.error(f"Error creating stop history item: {e}")
494
- # Create empty state as fallback
495
- state = self._create_empty_state()
496
- stop_history = AgentHistory(
497
- model_output=None,
498
- state=state,
499
- result=[ActionResult(extracted_content=None, error=None, is_done=True)]
500
- )
501
- self.history.history.append(stop_history)
502
-
503
- def _convert_to_browser_state_history(self, browser_state):
504
- return BrowserStateHistory(
505
- url=getattr(browser_state, 'url', ""),
506
- title=getattr(browser_state, 'title', ""),
507
- tabs=getattr(browser_state, 'tabs', []),
508
- interacted_element=[None],
509
- screenshot=getattr(browser_state, 'screenshot', None)
510
- )
511
-
512
- def _create_empty_state(self):
513
- return BrowserStateHistory(
514
- url="",
515
- title="",
516
- tabs=[],
517
- interacted_element=[None],
518
- screenshot=None
519
- )
 
1
+ import json
2
+ import logging
3
+ from typing import Optional, Type
4
+ from PIL import Image, ImageDraw, ImageFont
5
+ import os
6
+ import base64
7
+ import io
8
+
9
+ from browser_use.agent.prompts import SystemPrompt
10
+ from browser_use.agent.service import Agent
11
+ from browser_use.agent.views import (
12
+ ActionResult,
13
+ AgentHistoryList,
14
+ AgentOutput,
15
+ AgentHistory,
16
+ )
17
+ from browser_use.browser.browser import Browser
18
+ from browser_use.browser.context import BrowserContext
19
+ from browser_use.browser.views import BrowserStateHistory
20
+ from browser_use.controller.service import Controller
21
+ from browser_use.telemetry.views import (
22
+ AgentEndTelemetryEvent,
23
+ AgentRunTelemetryEvent,
24
+ AgentStepErrorTelemetryEvent,
25
+ )
26
+ from browser_use.utils import time_execution_async
27
+ from langchain_core.language_models.chat_models import BaseChatModel
28
+ from langchain_core.messages import (
29
+ BaseMessage,
30
+ )
31
+ from src.utils.agent_state import AgentState
32
+
33
+ from .custom_massage_manager import CustomMassageManager
34
+ from .custom_views import CustomAgentOutput, CustomAgentStepInfo
35
+
36
+ logger = logging.getLogger(__name__)
37
+
38
class CustomAgent(Agent):
    """Browser-use Agent subclass adding task hints (`add_infos`), persistent
    step memory, a custom message manager and stop-state handling.

    NOTE(review): reconstructed from a commit diff. The committed version
    called ``update_step_info``/``_log_response`` without defining them and
    ignored ``use_function_calling`` — those defects are fixed here.
    """

    def __init__(
        self,
        task: str,
        llm: BaseChatModel,
        add_infos: str = "",
        browser: Browser | None = None,
        browser_context: BrowserContext | None = None,
        controller: Controller = Controller(),
        use_vision: bool = True,
        save_conversation_path: Optional[str] = None,
        max_failures: int = 5,
        retry_delay: int = 10,
        system_prompt_class: Type[SystemPrompt] = SystemPrompt,
        max_input_tokens: int = 128000,
        validate_output: bool = False,
        include_attributes: list[str] = [
            "title",
            "type",
            "name",
            "role",
            "tabindex",
            "aria-label",
            "placeholder",
            "value",
            "alt",
            "aria-expanded",
        ],
        max_error_length: int = 400,
        max_actions_per_step: int = 10,
        tool_call_in_content: bool = True,
        agent_state: AgentState = None,
    ):
        # Signature (including the shared mutable defaults, which are only
        # read, never mutated) kept identical for backward compatibility.
        super().__init__(
            task=task,
            llm=llm,
            browser=browser,
            browser_context=browser_context,
            controller=controller,
            use_vision=use_vision,
            save_conversation_path=save_conversation_path,
            max_failures=max_failures,
            retry_delay=retry_delay,
            system_prompt_class=system_prompt_class,
            max_input_tokens=max_input_tokens,
            validate_output=validate_output,
            include_attributes=include_attributes,
            max_error_length=max_error_length,
            max_actions_per_step=max_actions_per_step,
            tool_call_in_content=tool_call_in_content,
        )
        if hasattr(self.llm, 'model_name') and self.llm.model_name in ["deepseek-reasoner"]:
            # deepseek-reasoner does not support function calling and only
            # supports a 64000-token context.
            self.use_function_calling = False
            self.max_input_tokens = 64000
        else:
            self.use_function_calling = True
        self.add_infos = add_infos
        self.agent_state = agent_state
        self.message_manager = CustomMassageManager(
            llm=self.llm,
            task=self.task,
            action_descriptions=self.controller.registry.get_prompt_description(),
            system_prompt_class=self.system_prompt_class,
            max_input_tokens=self.max_input_tokens,
            include_attributes=self.include_attributes,
            max_error_length=self.max_error_length,
            max_actions_per_step=self.max_actions_per_step,
            tool_call_in_content=tool_call_in_content,
            use_function_calling=self.use_function_calling,
        )

    def _setup_action_models(self) -> None:
        """Setup dynamic action models from the controller's registry.

        Overridden so that ``self.AgentOutput`` is the CustomAgentOutput
        variant (with prev_action_evaluation, task_progress, future_plans,
        thought, summary) that ``_log_response``/``update_step_info`` expect.
        """
        self.ActionModel = self.controller.registry.create_action_model()
        self.AgentOutput = CustomAgentOutput.type_with_custom_actions(self.ActionModel)

    def _log_response(self, response: CustomAgentOutput) -> None:
        """Log the model's evaluation, memory, plans and chosen actions."""
        if "Success" in response.current_state.prev_action_evaluation:
            emoji = "✅"
        elif "Failed" in response.current_state.prev_action_evaluation:
            emoji = "❌"
        else:
            emoji = "🤷"

        logger.info(f"{emoji} Eval: {response.current_state.prev_action_evaluation}")
        logger.info(f"🧠 New Memory: {response.current_state.important_contents}")
        logger.info(f"⏳ Task Progress: \n{response.current_state.task_progress}")
        logger.info(f"📋 Future Plans: \n{response.current_state.future_plans}")
        logger.info(f"🤔 Thought: {response.current_state.thought}")
        logger.info(f"🎯 Summary: {response.current_state.summary}")
        for i, action in enumerate(response.action):
            logger.info(
                f"🛠️ Action {i + 1}/{len(response.action)}: {action.model_dump_json(exclude_unset=True)}"
            )

    def update_step_info(
        self, model_output: CustomAgentOutput, step_info: CustomAgentStepInfo = None
    ):
        """Advance the step counter and fold new memory/progress/plans into
        ``step_info``. No-op when ``step_info`` is None.
        """
        if step_info is None:
            return

        step_info.step_number += 1
        important_contents = model_output.current_state.important_contents
        # Only accumulate genuinely new, non-placeholder memory.
        if (
            important_contents
            and "None" not in important_contents
            and important_contents not in step_info.memory
        ):
            step_info.memory += important_contents + "\n"

        task_progress = model_output.current_state.task_progress
        if task_progress and "None" not in task_progress:
            step_info.task_progress = task_progress

        future_plans = model_output.current_state.future_plans
        if future_plans and "None" not in future_plans:
            step_info.future_plans = future_plans

    def _parse_raw_response(self, ret) -> AgentOutput:
        """Parse a raw (non-structured) LLM message into an AgentOutput.

        The model may wrap its JSON payload in ```json fences; strip them
        before decoding.
        """
        content = ret.content[0] if isinstance(ret.content, list) else ret.content
        parsed_json = json.loads(content.replace("```json", "").replace("```", ""))
        parsed: AgentOutput = self.AgentOutput(**parsed_json)
        if parsed is None:
            raise ValueError('Could not parse response.')
        return parsed

    async def get_next_action(self, input_messages: list[BaseMessage]) -> AgentOutput:
        """Get the next action(s) from the LLM based on the current state.

        Uses structured output when the model supports function calling;
        otherwise (e.g. deepseek-reasoner) falls back to manually parsing
        a fenced-JSON response.
        """
        if self.use_function_calling:
            try:
                structured_llm = self.llm.with_structured_output(self.AgentOutput, include_raw=True)
                response = await structured_llm.ainvoke(input_messages)  # type: ignore
                parsed: AgentOutput = response['parsed']
            except Exception as e:
                # Temporary workaround for providers whose structured output
                # is unreliable: re-invoke and parse the raw content.
                logger.error(f"Error in get_next_action: {e}")
                parsed = self._parse_raw_response(await self.llm.ainvoke(input_messages))
        else:
            ret = await self.llm.ainvoke(input_messages)
            self.message_manager._add_message_with_tokens(ret)
            # deepseek-reasoner exposes its chain of thought separately.
            if hasattr(ret, "reasoning_content"):
                logger.info(f"🤯 Start Deep Thinking: ")
                logger.info(ret.reasoning_content)
                logger.info(f"🤯 End Deep Thinking")
            parsed = self._parse_raw_response(ret)

        # cut the number of actions to max_actions_per_step
        parsed.action = parsed.action[: self.max_actions_per_step]
        self._log_response(parsed)
        self.n_steps += 1
        return parsed

    async def step(self, step_info: Optional[CustomAgentStepInfo] = None) -> None:
        """Execute one step: observe browser state, query the LLM, run the
        chosen actions, and record a history item.
        """
        logger.info(f"Step {self.n_steps}")
        state = None
        model_output = None

        try:
            state = await self.browser_context.get_state(use_vision=self.use_vision)
            self.message_manager.add_state_message(state, self._last_result, step_info)
            input_messages = self.message_manager.get_messages()
            model_output = await self.get_next_action(input_messages)
            self.update_step_info(model_output, step_info)
            if self.use_function_calling:
                # Keep chat history lean: swap the full state message for the
                # model's structured output.
                self.message_manager._remove_last_state_message()
                self.message_manager.add_model_output(model_output)

            self._last_result = await self.controller.multi_act(model_output.action, self.browser_context)

            if len(self._last_result) > 0 and self._last_result[-1].is_done:
                logger.info(f"Task completed with result: {self._last_result[-1].extracted_content}")

            self.consecutive_failures = 0

        except Exception as e:
            logger.error(f"Error in step: {e}")
            self._last_result = self._handle_step_error(e)

        finally:
            # Only record history when a state was actually captured.
            if state:
                self._make_history_item(model_output, state, self._last_result)