import asyncio
import re
from inspect import iscoroutinefunction, signature
from typing import Any, Callable, Dict, Generic, Optional, Type, TypeVar

from langchain_core.language_models.chat_models import BaseChatModel
from pydantic import BaseModel, Field, create_model

from browser_use.browser.context import BrowserContext
from browser_use.controller.registry.views import (
    ActionModel,
    ActionRegistry,
    RegisteredAction,
)
from browser_use.telemetry.service import ProductTelemetry
from browser_use.telemetry.views import (
    ControllerRegisteredFunctionsTelemetryEvent,
    RegisteredFunction,
)
from browser_use.utils import time_execution_async, time_execution_sync

Context = TypeVar('Context')

# Parameter names that execute_action() supplies itself. They are never part of
# the generated parameter model, otherwise they would be passed twice.
_INJECTED_PARAM_NAMES = ('browser', 'page_extraction_llm', 'available_file_paths', 'context')

# Placeholder syntax for secrets inside string parameters:
# "<secret>key</secret>" is replaced by sensitive_data["key"].
# NOTE(review): the delimiters were missing from the mangled source (the
# pattern was just '(.*?)', which can only match the empty string); they are
# restored here as <secret>…</secret> - confirm against the prompt format.
_SECRET_PATTERN = re.compile(r'<secret>(.*?)</secret>')


def _is_model_class(annotation: Any) -> bool:
    """Return True only when *annotation* is a class derived from BaseModel."""
    try:
        return issubclass(annotation, BaseModel)
    except TypeError:
        # Raised for non-class annotations such as Optional[str], list[str]
        # or string annotations; those are never pydantic models.
        return False


class Registry(Generic[Context]):
    """Service for registering and managing actions"""

    def __init__(self, exclude_actions: list[str] | None = None):
        """Create an empty registry.

        Args:
            exclude_actions: Names of functions that the ``action`` decorator
                must leave unregistered. ``None`` means nothing is excluded.
        """
        self.registry = ActionRegistry()
        self.telemetry = ProductTelemetry()
        self.exclude_actions = exclude_actions if exclude_actions is not None else []

    @time_execution_sync('--create_param_model')
    def _create_param_model(self, function: Callable) -> Type[BaseModel]:
        """Creates a Pydantic model from function signature

        Every parameter of *function* becomes a field, except the names in
        ``_INJECTED_PARAM_NAMES`` which ``execute_action`` passes in itself.
        Parameters without a default become required fields (``...``).
        """
        sig = signature(function)
        params = {
            # Identity check: `==` would invoke the default value's own __eq__.
            name: (param.annotation, ... if param.default is param.empty else param.default)
            for name, param in sig.parameters.items()
            if name not in _INJECTED_PARAM_NAMES
        }
        # TODO: make the types here work
        return create_model(
            f'{function.__name__}_parameters',
            __base__=ActionModel,
            **params,  # type: ignore
        )

    def action(
        self,
        description: str,
        param_model: Optional[Type[BaseModel]] = None,
    ):
        """Decorator for registering actions

        Args:
            description: Text stored on the registered action and used as the
                field description in ``create_action_model``.
            param_model: Model used to validate the action's parameters. When
                omitted, one is derived from the function signature.

        Returns:
            A decorator that registers the function under its ``__name__`` and
            returns the original function unchanged.
        """

        def decorator(func: Callable):
            # Skip registration if action is in exclude_actions
            if func.__name__ in self.exclude_actions:
                return func

            # Create param model from function if not provided
            actual_param_model = param_model or self._create_param_model(func)

            # Wrap sync functions to make them async
            if not iscoroutinefunction(func):

                async def async_wrapper(*args, **kwargs):
                    # Run the blocking callable in a worker thread so the
                    # event loop is not blocked while it executes.
                    return await asyncio.to_thread(func, *args, **kwargs)

                # Copy the signature and other metadata from the original
                # function; execute_action() inspects the signature to decide
                # which extra arguments to pass.
                async_wrapper.__signature__ = signature(func)
                async_wrapper.__name__ = func.__name__
                async_wrapper.__annotations__ = func.__annotations__
                wrapped_func = async_wrapper
            else:
                wrapped_func = func

            action = RegisteredAction(
                name=func.__name__,
                description=description,
                function=wrapped_func,
                param_model=actual_param_model,
            )
            self.registry.actions[func.__name__] = action
            return func

        return decorator

    @time_execution_async('--execute_action')
    async def execute_action(
        self,
        action_name: str,
        params: dict,
        browser: Optional[BrowserContext] = None,
        page_extraction_llm: Optional[BaseChatModel] = None,
        sensitive_data: Optional[Dict[str, str]] = None,
        available_file_paths: Optional[list[str]] = None,
        context: Optional[Context] = None,
    ) -> Any:
        """Execute a registered action

        Args:
            action_name: Name the action was registered under.
            params: Raw parameters, validated against the action's param model.
            browser: Passed to the action if it declares a ``browser`` parameter.
            page_extraction_llm: Passed if the action declares a
                ``page_extraction_llm`` parameter.
            sensitive_data: Mapping of placeholder name to secret value; when
                given, placeholders in the validated parameters are replaced.
            available_file_paths: Passed if the action declares an
                ``available_file_paths`` parameter.
            context: Passed if the action declares a ``context`` parameter.

        Returns:
            Whatever the action function returns.

        Raises:
            ValueError: If no action named *action_name* is registered.
            RuntimeError: If validation, a missing required extra argument, or
                the action itself fails; the original error is chained.
        """
        if action_name not in self.registry.actions:
            raise ValueError(f'Action {action_name} not found')

        action = self.registry.actions[action_name]
        try:
            # Create the validated Pydantic model
            validated_params = action.param_model(**params)

            # Check if the first parameter is a Pydantic model
            parameters = list(signature(action.function).parameters.values())
            is_pydantic = bool(parameters) and _is_model_class(parameters[0].annotation)
            parameter_names = [param.name for param in parameters]

            if sensitive_data:
                validated_params = self._replace_sensitive_data(validated_params, sensitive_data)

            # Values the action may ask for by parameter name. A requested
            # value that is missing (falsy) is an error.
            injectable = {
                'browser': browser,
                'page_extraction_llm': page_extraction_llm,
                'available_file_paths': available_file_paths,
                'context': context,
            }
            extra_args: Dict[str, Any] = {}
            for name, value in injectable.items():
                if name not in parameter_names:
                    continue
                if not value:
                    raise ValueError(f'Action {action_name} requires {name} but none provided.')
                extra_args[name] = value

            if action_name == 'input_text' and sensitive_data:
                extra_args['has_sensitive_data'] = True

            if is_pydantic:
                return await action.function(validated_params, **extra_args)
            return await action.function(**validated_params.model_dump(), **extra_args)

        except Exception as e:
            raise RuntimeError(f'Error executing action {action_name}: {str(e)}') from e

    def _replace_sensitive_data(self, params: BaseModel, sensitive_data: Dict[str, str]) -> BaseModel:
        """Replaces the sensitive data in the params

        Walks string values (including those nested in dicts, lists and nested
        models) and substitutes every placeholder whose name is a key of
        *sensitive_data*. Unknown placeholders are left untouched.

        Returns:
            A copy of *params* with the substitutions applied. The input model
            is not mutated and nested models keep their model type.
        """

        def replace_secrets(value):
            if isinstance(value, str):
                for placeholder in _SECRET_PATTERN.findall(value):
                    if placeholder in sensitive_data:
                        value = value.replace(f'<secret>{placeholder}</secret>', sensitive_data[placeholder])
                return value
            if isinstance(value, BaseModel):
                # Rebuild via model_copy so nested models stay models instead
                # of degrading to dicts; update values are not re-validated.
                replaced = {name: replace_secrets(getattr(value, name)) for name in type(value).model_fields}
                return value.model_copy(update=replaced)
            if isinstance(value, dict):
                return {k: replace_secrets(v) for k, v in value.items()}
            if isinstance(value, list):
                return [replace_secrets(v) for v in value]
            return value

        return replace_secrets(params)

    @time_execution_sync('--create_action_model')
    def create_action_model(self, include_actions: Optional[list[str]] = None) -> Type[ActionModel]:
        """Creates a Pydantic model from registered actions

        Args:
            include_actions: When given, only actions with these names are
                included; ``None`` includes every registered action.

        Returns:
            An ``ActionModel`` subclass with one optional field per included
            action, typed with that action's param model. The included
            actions are also reported through ``self.telemetry``.
        """
        selected = {
            name: action
            for name, action in self.registry.actions.items()
            if include_actions is None or name in include_actions
        }

        fields = {
            name: (
                Optional[action.param_model],
                Field(default=None, description=action.description),
            )
            for name, action in selected.items()
        }

        self.telemetry.capture(
            ControllerRegisteredFunctionsTelemetryEvent(
                registered_functions=[
                    RegisteredFunction(name=name, params=action.param_model.model_json_schema())
                    for name, action in selected.items()
                ]
            )
        )

        return create_model('ActionModel', __base__=ActionModel, **fields)  # type:ignore

    def get_prompt_description(self) -> str:
        """Get a description of all actions for the prompt"""
        return self.registry.get_prompt_description()