|
|
import asyncio |
|
|
from inspect import iscoroutinefunction, signature |
|
|
from typing import Any, Callable, Dict, Generic, Optional, Type, TypeVar |
|
|
|
|
|
from langchain_core.language_models.chat_models import BaseChatModel |
|
|
from pydantic import BaseModel, Field, create_model |
|
|
|
|
|
from browser_use.browser.context import BrowserContext |
|
|
from browser_use.controller.registry.views import ( |
|
|
ActionModel, |
|
|
ActionRegistry, |
|
|
RegisteredAction, |
|
|
) |
|
|
from browser_use.telemetry.service import ProductTelemetry |
|
|
from browser_use.telemetry.views import ( |
|
|
ControllerRegisteredFunctionsTelemetryEvent, |
|
|
RegisteredFunction, |
|
|
) |
|
|
from browser_use.utils import time_execution_async, time_execution_sync |
|
|
|
|
|
# Type of the caller-defined object that Registry.execute_action forwards to
# actions whose signature declares a 'context' parameter.
Context = TypeVar('Context')
|
|
|
|
|
|
|
|
class Registry(Generic[Context]):
	"""Service for registering and managing actions.

	Actions are registered with the ``action`` decorator and stored in an
	``ActionRegistry``. Each action carries a Pydantic parameter model (supplied
	explicitly or generated from the function signature) that ``execute_action``
	uses to validate incoming parameters.
	"""

	# Parameter names whose values execute_action supplies itself instead of taking
	# them from the validated action parameters. They are therefore left out of the
	# parameter models generated by _create_param_model.
	_INJECTED_PARAMS = frozenset({'browser', 'page_extraction_llm', 'available_file_paths', 'context'})

	def __init__(self, exclude_actions: list[str] | None = None):
		"""Create an empty registry.

		Args:
			exclude_actions: Function names that the ``action`` decorator should
				return unregistered.
		"""
		self.registry = ActionRegistry()
		self.telemetry = ProductTelemetry()
		self.exclude_actions = exclude_actions if exclude_actions is not None else []

	@time_execution_sync('--create_param_model')
	def _create_param_model(self, function: Callable) -> Type[BaseModel]:
		"""Creates a Pydantic model from function signature.

		Every parameter except the ones injected by ``execute_action`` becomes a
		field. Parameters without a default are required; the others keep their
		default. Unannotated parameters are typed as ``Any``.
		"""
		sig = signature(function)
		params = {
			name: (
				# An unannotated parameter would otherwise use inspect's "empty"
				# sentinel class as the field type.
				Any if param.annotation is param.empty else param.annotation,
				# Ellipsis marks the field as required for create_model.
				... if param.default is param.empty else param.default,
			)
			for name, param in sig.parameters.items()
			if name not in self._INJECTED_PARAMS
		}
		return create_model(
			f'{function.__name__}_parameters',
			__base__=ActionModel,
			**params,
		)

	def action(
		self,
		description: str,
		param_model: Optional[Type[BaseModel]] = None,
	):
		"""Decorator for registering actions.

		Args:
			description: Description stored on the registered action.
			param_model: Pydantic model used to validate the action's parameters.
				When omitted, one is generated from the function signature.

		Returns:
			A decorator that registers the function (unless its name is listed in
			``exclude_actions``) and returns the function unchanged.
		"""

		def decorator(func: Callable):
			# Excluded actions are handed back without being registered.
			if func.__name__ in self.exclude_actions:
				return func

			actual_param_model = param_model or self._create_param_model(func)

			# execute_action always awaits the registered function, so synchronous
			# functions are wrapped to run in a worker thread via asyncio.to_thread.
			if not iscoroutinefunction(func):

				async def async_wrapper(*args, **kwargs):
					return await asyncio.to_thread(func, *args, **kwargs)

				# Expose the original signature so execute_action can inspect the
				# parameter names and the first parameter's annotation.
				async_wrapper.__signature__ = signature(func)
				async_wrapper.__name__ = func.__name__
				async_wrapper.__annotations__ = func.__annotations__
				wrapped_func = async_wrapper
			else:
				wrapped_func = func

			self.registry.actions[func.__name__] = RegisteredAction(
				name=func.__name__,
				description=description,
				function=wrapped_func,
				param_model=actual_param_model,
			)
			# The undecorated function is returned, not the async wrapper.
			return func

		return decorator

	@time_execution_async('--execute_action')
	async def execute_action(
		self,
		action_name: str,
		params: dict,
		browser: Optional[BrowserContext] = None,
		page_extraction_llm: Optional[BaseChatModel] = None,
		sensitive_data: Optional[Dict[str, str]] = None,
		available_file_paths: Optional[list[str]] = None,
		context: Context | None = None,
	) -> Any:
		"""Execute a registered action.

		Args:
			action_name: Name of a registered action.
			params: Raw parameters, validated against the action's parameter model.
			browser: Passed to the action only if its signature declares ``browser``.
			page_extraction_llm: Passed only if the signature declares it.
			sensitive_data: Mapping used to replace ``<secret>name</secret>``
				placeholders in the validated parameters. When non-empty and the
				action is ``input_text``, ``has_sensitive_data=True`` is also passed.
			available_file_paths: Passed only if the signature declares it.
			context: Passed only if the signature declares ``context``.

		Returns:
			Whatever the action function returns.

		Raises:
			ValueError: If no action named ``action_name`` is registered.
			RuntimeError: If validation fails, a dependency the action declares was
				not provided, or the action raises. The original exception is chained.
		"""
		if action_name not in self.registry.actions:
			raise ValueError(f'Action {action_name} not found')

		action = self.registry.actions[action_name]
		try:
			validated_params = action.param_model(**params)

			parameters = list(signature(action.function).parameters.values())
			parameter_names = {param.name for param in parameters}
			# A function whose first parameter is annotated with a Pydantic model gets
			# the validated model itself; otherwise the fields are passed as keyword
			# arguments. The isinstance guard keeps non-class annotations (strings,
			# generics such as Optional[...]) from making issubclass raise TypeError.
			first_annotation = parameters[0].annotation if parameters else None
			is_pydantic = isinstance(first_annotation, type) and issubclass(first_annotation, BaseModel)

			if sensitive_data:
				validated_params = self._replace_sensitive_data(validated_params, sensitive_data)

			if 'browser' in parameter_names and browser is None:
				raise ValueError(f'Action {action_name} requires browser but none provided.')
			if 'page_extraction_llm' in parameter_names and page_extraction_llm is None:
				raise ValueError(f'Action {action_name} requires page_extraction_llm but none provided.')
			# An empty list of file paths is treated the same as no list at all.
			if 'available_file_paths' in parameter_names and not available_file_paths:
				raise ValueError(f'Action {action_name} requires available_file_paths but none provided.')
			# Compare against None: the context type is caller-defined, and a valid
			# context object may still be falsy (e.g. an empty dict).
			if 'context' in parameter_names and context is None:
				raise ValueError(f'Action {action_name} requires context but none provided.')

			injectable = {
				'context': context,
				'browser': browser,
				'page_extraction_llm': page_extraction_llm,
				'available_file_paths': available_file_paths,
			}
			# Only pass the dependencies the action's signature actually declares.
			extra_args = {name: value for name, value in injectable.items() if name in parameter_names}
			if action_name == 'input_text' and sensitive_data:
				extra_args['has_sensitive_data'] = True

			if is_pydantic:
				return await action.function(validated_params, **extra_args)
			return await action.function(**validated_params.model_dump(), **extra_args)

		except Exception as e:
			raise RuntimeError(f'Error executing action {action_name}: {str(e)}') from e

	def _replace_sensitive_data(self, params: BaseModel, sensitive_data: Dict[str, str]) -> BaseModel:
		"""Replaces the sensitive data in the params.

		Every ``<secret>name</secret>`` placeholder in a string value, searched
		recursively through nested models, dicts and lists, is replaced by
		``sensitive_data[name]``. Placeholders whose name is not a key of
		``sensitive_data`` are left untouched.

		Returns:
			A shallow copy of ``params`` (and of any nested model containing
			replacements). Nested Pydantic models stay model instances instead of
			being flattened to dicts, and the input model is not mutated.
		"""
		import re

		secret_pattern = re.compile(r'<secret>(.*?)</secret>')

		def substitute(match: re.Match) -> str:
			# Unknown placeholders are kept verbatim.
			return sensitive_data.get(match.group(1), match.group(0))

		def replace_secrets(value: Any) -> Any:
			if isinstance(value, str):
				# Single pass, so text inserted from one secret is never re-scanned.
				return secret_pattern.sub(substitute, value)
			if isinstance(value, BaseModel):
				# model_copy() gives a shallow copy with its own __dict__ and extras.
				# Writing into those (instead of passing update=...) leaves
				# model_fields_set unchanged.
				copied = value.model_copy()
				for field_name, field_value in value.__dict__.items():
					copied.__dict__[field_name] = replace_secrets(field_value)
				extra = copied.__pydantic_extra__
				if extra:
					for extra_name, extra_value in list(extra.items()):
						extra[extra_name] = replace_secrets(extra_value)
				return copied
			if isinstance(value, dict):
				return {k: replace_secrets(v) for k, v in value.items()}
			if isinstance(value, list):
				return [replace_secrets(v) for v in value]
			return value

		return replace_secrets(params)

	@time_execution_sync('--create_action_model')
	def create_action_model(self, include_actions: Optional[list[str]] = None) -> Type[ActionModel]:
		"""Creates a Pydantic model from registered actions.

		Each selected action becomes an optional field, defaulting to None, named
		after the action. The field is typed with the action's parameter model and
		described by the action's description. Also captures a telemetry event
		listing the selected actions with their parameter JSON schemas.

		Args:
			include_actions: Names of the actions to include; all registered
				actions when None.
		"""
		selected = {
			name: action
			for name, action in self.registry.actions.items()
			if include_actions is None or name in include_actions
		}

		fields = {
			name: (
				Optional[action.param_model],
				Field(default=None, description=action.description),
			)
			for name, action in selected.items()
		}

		self.telemetry.capture(
			ControllerRegisteredFunctionsTelemetryEvent(
				registered_functions=[
					RegisteredFunction(name=name, params=action.param_model.model_json_schema())
					for name, action in selected.items()
				]
			)
		)

		return create_model('ActionModel', __base__=ActionModel, **fields)

	def get_prompt_description(self) -> str:
		"""Get a description of all actions for the prompt (delegates to the ActionRegistry)."""
		return self.registry.get_prompt_description()
|
|
|