Spaces:
Sleeping
Sleeping
| import asyncio | |
| import functools | |
| import inspect | |
| import logging | |
| import re | |
| from collections.abc import Callable | |
| from inspect import Parameter, iscoroutinefunction, signature | |
| from types import UnionType | |
| from typing import Any, Generic, Optional, TypeVar, Union, get_args, get_origin | |
| import pyotp | |
| from pydantic import BaseModel, Field, RootModel, create_model | |
| from browser_use.browser import BrowserSession | |
| from browser_use.filesystem.file_system import FileSystem | |
| from browser_use.llm.base import BaseChatModel | |
| from browser_use.observability import observe_debug | |
| from browser_use.telemetry.service import ProductTelemetry | |
| from browser_use.tools.registry.views import ( | |
| ActionModel, | |
| ActionRegistry, | |
| RegisteredAction, | |
| SpecialActionParameters, | |
| ) | |
| from browser_use.utils import is_new_tab_page, match_url_with_domain_pattern, time_execution_async | |
# Type variable for the context object that Registry is generic over.
Context = TypeVar('Context')

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
class Registry(Generic[Context]):
    """Service for registering and managing actions.

    Callables are registered through the :meth:`action` decorator. Each one is
    wrapped so that it is invoked with keyword arguments only
    (``params=<model>`` plus injected "special" parameters such as
    ``browser_session``) and is stored in ``self.registry.actions`` together
    with the pydantic model that describes its parameters.
    """

    # Special parameters whose absence is reported with the dedicated
    # "Action <name> requires <param> but none provided." message.
    # execute_action() matches on part of that text, so the wording must stay stable.
    _NAMED_REQUIRED_SPECIAL_PARAMS = frozenset(
        {'browser_session', 'page_extraction_llm', 'file_system', 'page', 'available_file_paths'}
    )

    # Matches <secret>placeholder</secret> tags; group 1 is the placeholder name.
    _SECRET_PATTERN = re.compile(r'<secret>(.*?)</secret>')

    def __init__(self, exclude_actions: list[str] | None = None):
        """Create an empty registry.

        Args:
            exclude_actions: Names of functions that the :meth:`action`
                decorator must not register. ``None`` means nothing is excluded.
        """
        self.registry = ActionRegistry()
        self.telemetry = ProductTelemetry()
        self.exclude_actions = exclude_actions if exclude_actions is not None else []

    def _get_special_param_types(self) -> dict[str, type | UnionType | None]:
        """Get the expected types for special parameters.

        Returns:
            Mapping of special parameter name to the type an action may annotate
            it with. ``None`` means the annotation is not validated.
        """
        # Manually define the expected types to avoid issues with Optional handling.
        # we should try to reduce this list to 0 if possible, give as few standardized objects to all the actions
        # but each driver should decide what is relevant to expose the action methods,
        # e.g. CDP client, 2fa code getters, sensitive_data wrappers, other context, etc.
        return {
            'context': None,  # Context is a TypeVar, so we can't validate type
            'browser_session': BrowserSession,
            'page_url': str,
            'cdp_client': None,  # CDPClient type from cdp_use, but we don't import it here
            'page_extraction_llm': BaseChatModel,
            'available_file_paths': list,
            'has_sensitive_data': bool,
            'file_system': FileSystem,
        }

    @classmethod
    def _missing_special_param_error(cls, func_name: str, param_name: str) -> ValueError:
        """Build the ValueError for a required special parameter that is missing or None."""
        if param_name in cls._NAMED_REQUIRED_SPECIAL_PARAMS:
            return ValueError(f'Action {func_name} requires {param_name} but none provided.')
        return ValueError(f"{func_name}() missing required special parameter '{param_name}'")

    @staticmethod
    def _strip_optional(annotation: Any) -> Any:
        """Return the first non-None member of a union annotation, else the annotation itself.

        Both ``typing.Union[X, None]`` / ``Optional[X]`` and the PEP 604 form
        ``X | None`` are recognised; the latter reports ``types.UnionType`` as
        its origin on Python 3.10-3.13.
        """
        origin = get_origin(annotation)
        if origin is Union or origin is UnionType:
            return next((arg for arg in get_args(annotation) if arg is not type(None)), annotation)
        return annotation

    def _normalize_action_function_signature(
        self,
        func: Callable,
        description: str,
        param_model: type[BaseModel] | None = None,
    ) -> tuple[Callable, type[BaseModel]]:
        """Normalize an action function so that it accepts only kwargs.

        Two calling conventions are supported for ``func``:

        * Type 1 - ``param_model`` is given and the first parameter of ``func``
          is not a special parameter: that first parameter receives the
          validated model instance.
        * Type 2 - otherwise: every non-special parameter is treated as an
          action parameter and, when no ``param_model`` is given, a model is
          generated from them.

        Args:
            func: The action implementation (sync or async).
            description: Action description; not used by this method.
            param_model: Optional explicit pydantic model for the action params.

        Returns:
            - Normalized async function that accepts (*_, params: ParamModel, **special_params)
            - The param model to use for registration

        Raises:
            ValueError: If ``func`` declares ``**kwargs`` or annotates a special
                parameter with a type that conflicts with the injected one.
        """
        sig = signature(func)
        parameters = list(sig.parameters.values())
        special_param_types = self._get_special_param_types()
        special_param_names = set(special_param_types.keys())

        # Step 1: Validate no **kwargs in original function signature
        # if it needs default values it must use a dedicated param_model: BaseModel instead
        for param in parameters:
            if param.kind == Parameter.VAR_KEYWORD:
                raise ValueError(
                    f"Action '{func.__name__}' has **{param.name} which is not allowed. "
                    f'Actions must have explicit positional parameters only.'
                )

        # Step 2: Separate special and action parameters
        param_model_provided = param_model is not None
        # Type 1 pattern: the first positional parameter is the params model itself.
        first_param_is_model = (
            param_model_provided and bool(parameters) and parameters[0].name not in special_param_names
        )

        action_params: list[Parameter] = []
        special_params: list[Parameter] = []
        for i, param in enumerate(parameters):
            if i == 0 and first_param_is_model:
                # The model parameter is neither an action param nor a special param.
                continue

            if param.name not in special_param_names:
                action_params.append(param)
                continue

            # Validate the annotation of a special parameter (when both sides are known).
            expected_type = special_param_types.get(param.name)
            if param.annotation is not Parameter.empty and expected_type is not None:
                param_type = self._strip_optional(param.annotation)

                # Compatible means: exact match, subclass, or list[T] where list is expected.
                types_compatible = (
                    param_type == expected_type
                    or (
                        inspect.isclass(param_type)
                        and inspect.isclass(expected_type)
                        and issubclass(param_type, expected_type)
                    )
                    or (expected_type is list and (param_type is list or get_origin(param_type) is list))
                )
                if not types_compatible:
                    expected_type_name = getattr(expected_type, '__name__', str(expected_type))
                    param_type_name = getattr(param_type, '__name__', str(param_type))
                    raise ValueError(
                        f"Action '{func.__name__}' parameter '{param.name}: {param_type_name}' "
                        f"conflicts with special argument injected by tools: '{param.name}: {expected_type_name}'"
                    )
            special_params.append(param)

        # Step 3: Create or validate param model
        if not param_model_provided:
            # Type 2: Generate param model from action params (empty model when there are none).
            # Unannotated parameters default to str; parameters without a default are required.
            generated_fields = {
                param.name: (
                    param.annotation if param.annotation is not Parameter.empty else str,
                    ... if param.default is Parameter.empty else param.default,
                )
                for param in action_params
            }
            param_model = create_model(f'{func.__name__}_Params', __base__=ActionModel, **generated_fields)

        assert param_model is not None, f'param_model is None for {func.__name__}'

        # Step 4: Create normalized wrapper function
        async def normalized_wrapper(*args, params: BaseModel | None = None, **kwargs):
            """Normalized action that only accepts kwargs"""
            # Validate no positional args
            if args:
                raise TypeError(f'{func.__name__}() does not accept positional arguments, only keyword arguments are allowed')

            if first_param_is_model:
                # Type 1: the params object is forwarded as-is, so it must be present.
                if params is None:
                    raise ValueError(f"{func.__name__}() missing required 'params' argument")
            elif params is None and action_params:
                # Type 2 without a model instance: try to build one from loose kwargs.
                action_kwargs = {param.name: kwargs[param.name] for param in action_params if param.name in kwargs}
                if action_kwargs:
                    # Use the param_model which has the correct types defined
                    params = param_model(**action_kwargs)

            # Build call_args by iterating through original function parameters in order
            params_dict = params.model_dump() if params is not None else {}
            call_args = []
            for i, param in enumerate(parameters):
                if i == 0 and first_param_is_model:
                    # Type 1: first positional argument is the model itself.
                    call_args.append(params)
                elif param.name in special_param_names:
                    if param.name in kwargs:
                        value = kwargs[param.name]
                        # An explicit None does not satisfy a special param that has no default.
                        if value is None and param.default is Parameter.empty:
                            raise self._missing_special_param_error(func.__name__, param.name)
                        call_args.append(value)
                    elif param.default is not Parameter.empty:
                        call_args.append(param.default)
                    else:
                        # Special param is required but not provided
                        raise self._missing_special_param_error(func.__name__, param.name)
                elif param.name in params_dict:
                    call_args.append(params_dict[param.name])
                elif param.default is not Parameter.empty:
                    call_args.append(param.default)
                else:
                    raise ValueError(f"{func.__name__}() missing required parameter '{param.name}'")

            # Call original function with positional args; sync functions run in a worker thread.
            if iscoroutinefunction(func):
                return await func(*call_args)
            return await asyncio.to_thread(func, *call_args)

        # Update wrapper signature to be kwargs-only
        new_params = [Parameter('params', Parameter.KEYWORD_ONLY, default=None, annotation=Optional[param_model])]
        # Add special params as keyword-only
        new_params.extend(
            Parameter(sp.name, Parameter.KEYWORD_ONLY, default=sp.default, annotation=sp.annotation)
            for sp in special_params
        )
        # Add **kwargs to accept and ignore extra params
        new_params.append(Parameter('kwargs', Parameter.VAR_KEYWORD))
        normalized_wrapper.__signature__ = sig.replace(parameters=new_params)  # type: ignore[attr-defined]

        return normalized_wrapper, param_model

    # @time_execution_sync('--create_param_model')
    def _create_param_model(self, function: Callable) -> type[BaseModel]:
        """Creates a Pydantic model from function signature.

        Parameters named like a field of ``SpecialActionParameters`` are left
        out; parameters without a default become required fields.
        """
        sig = signature(function)
        special_param_names = set(SpecialActionParameters.model_fields.keys())
        params = {
            name: (param.annotation, ... if param.default is Parameter.empty else param.default)
            for name, param in sig.parameters.items()
            if name not in special_param_names
        }
        # TODO: make the types here work
        return create_model(
            f'{function.__name__}_parameters',
            __base__=ActionModel,
            **params,  # type: ignore
        )

    def action(
        self,
        description: str,
        param_model: type[BaseModel] | None = None,
        domains: list[str] | None = None,
        allowed_domains: list[str] | None = None,
    ):
        """Decorator for registering actions.

        Args:
            description: Description stored with the registered action.
            param_model: Optional explicit pydantic model for the action params.
            domains: Domain patterns stored on the registered action.
            allowed_domains: Alias of ``domains``; only one of the two may be given.

        Returns:
            A decorator. It returns the normalized (kwargs-only, async) wrapper,
            or the untouched function when its name is in ``exclude_actions``.

        Raises:
            ValueError: If both ``domains`` and ``allowed_domains`` are given.
        """
        # Handle aliases: domains and allowed_domains are the same parameter
        if allowed_domains is not None and domains is not None:
            raise ValueError("Cannot specify both 'domains' and 'allowed_domains' - they are aliases for the same parameter")
        final_domains = allowed_domains if allowed_domains is not None else domains

        def decorator(func: Callable):
            # Skip registration if action is in exclude_actions
            if func.__name__ in self.exclude_actions:
                return func

            # Normalize the function signature
            normalized_func, actual_param_model = self._normalize_action_function_signature(func, description, param_model)

            self.registry.actions[func.__name__] = RegisteredAction(
                name=func.__name__,
                description=description,
                function=normalized_func,
                param_model=actual_param_model,
                domains=final_domains,
            )

            # Return the normalized function so it can be called with kwargs
            return normalized_func

        return decorator

    async def execute_action(
        self,
        action_name: str,
        params: dict,
        browser_session: BrowserSession | None = None,
        page_extraction_llm: BaseChatModel | None = None,
        file_system: FileSystem | None = None,
        sensitive_data: dict[str, str | dict[str, str]] | None = None,
        available_file_paths: list[str] | None = None,
    ) -> Any:
        """Execute a registered action with simplified parameter handling.

        ``params`` is validated against the action's param model, secret
        placeholders are substituted when ``sensitive_data`` is given, and the
        action is awaited with the validated model plus the special context.

        Returns:
            Whatever the action function returns.

        Raises:
            ValueError: If ``action_name`` is not registered.
            RuntimeError: For any failure during validation or execution; the
                original exception is chained as ``__cause__``.
        """
        if action_name not in self.registry.actions:
            raise ValueError(f'Action {action_name} not found')

        action = self.registry.actions[action_name]
        try:
            # Create the validated Pydantic model
            try:
                validated_params = action.param_model(**params)
            except Exception as e:
                raise ValueError(f'Invalid parameters {params} for action {action_name}: {type(e)}: {e}') from e

            if sensitive_data:
                # Get current URL if browser_session is provided
                current_url = None
                if browser_session and browser_session.current_target_id:
                    try:
                        # Get current page info using CDP
                        targets = await browser_session.cdp_client.send.Target.getTargets()
                        for target in targets.get('targetInfos', []):
                            if target.get('targetId') == browser_session.current_target_id:
                                current_url = target.get('url')
                                break
                    except Exception:
                        # Best effort: without a URL only non-domain-scoped secrets apply.
                        logger.debug('Could not determine current URL for sensitive data matching', exc_info=True)
                validated_params = self._replace_sensitive_data(validated_params, sensitive_data, current_url)

            # Build special context dict
            special_context = {
                'browser_session': browser_session,
                'page_extraction_llm': page_extraction_llm,
                'available_file_paths': available_file_paths,
                'has_sensitive_data': action_name == 'input' and bool(sensitive_data),
                'file_system': file_system,
            }

            # Only pass sensitive_data to actions that explicitly need it (input)
            if action_name == 'input':
                special_context['sensitive_data'] = sensitive_data

            # Add CDP-related parameters if browser_session is available
            if browser_session:
                # Add page_url
                try:
                    special_context['page_url'] = await browser_session.get_current_page_url()
                except Exception:
                    special_context['page_url'] = None
                # Add cdp_client
                special_context['cdp_client'] = browser_session.cdp_client

            # All functions are now normalized to accept kwargs only
            # Call with params and unpacked special context
            return await action.function(params=validated_params, **special_context)

        except ValueError as e:
            # Preserve ValueError messages from validation
            message = str(e)
            if (
                'requires browser_session but none provided' in message
                or 'requires page_extraction_llm but none provided' in message
            ):
                raise RuntimeError(message) from e
            raise RuntimeError(f'Error executing action {action_name}: {message}') from e
        except TimeoutError as e:
            raise RuntimeError(f'Error executing action {action_name} due to timeout.') from e
        except Exception as e:
            raise RuntimeError(f'Error executing action {action_name}: {str(e)}') from e

    def _log_sensitive_data_usage(self, placeholders_used: set[str], current_url: str | None) -> None:
        """Log when sensitive data is being used on a page.

        Nothing is logged for an empty set. The URL is omitted from the message
        when it is missing or is a new-tab page.
        """
        if placeholders_used:
            url_info = f' on {current_url}' if current_url and not is_new_tab_page(current_url) else ''
            logger.info(f'🔒 Using sensitive data placeholders: {", ".join(sorted(placeholders_used))}{url_info}')

    def _replace_sensitive_data(
        self, params: BaseModel, sensitive_data: dict[str, Any], current_url: str | None = None
    ) -> BaseModel:
        """
        Replaces sensitive data placeholders in params with actual values.

        Args:
            params: The parameter object containing <secret>placeholder</secret> tags
            sensitive_data: Dictionary of sensitive data, either in old format {key: value}
                or new format {domain_pattern: {key: value}}
            current_url: Optional current URL for domain matching

        Returns:
            BaseModel: The parameter object with placeholders replaced by actual values
        """
        # Placeholders that had no (non-empty) secret, collected across the full object
        all_missing_placeholders: set[str] = set()
        # Placeholders that were successfully replaced
        replaced_placeholders: set[str] = set()

        # Process sensitive data based on format and current URL
        applicable_secrets: dict[str, Any] = {}
        for domain_or_key, content in sensitive_data.items():
            if isinstance(content, dict):
                # New format: {domain_pattern: {key: value}}
                # Only include secrets for domains that match the current URL
                if current_url and not is_new_tab_page(current_url):
                    # it's a real url, check it using our custom allowed_domains scheme://*.example.com glob matching
                    if match_url_with_domain_pattern(current_url, domain_or_key):
                        applicable_secrets.update(content)
            else:
                # Old format: {key: value}, expose to all domains (only allowed for legacy reasons)
                applicable_secrets[domain_or_key] = content

        # Filter out empty values
        applicable_secrets = {k: v for k, v in applicable_secrets.items() if v}

        def recursively_replace_secrets(value: str | dict | list) -> str | dict | list:
            """Substitute placeholders in strings, descending into dicts and lists."""
            if isinstance(value, str):
                # check if the placeholder key, like x_password is in the output parameters of the LLM
                # and replace it with the sensitive data
                for placeholder in self._SECRET_PATTERN.findall(value):
                    if placeholder not in applicable_secrets:
                        # Keep track of missing placeholders; the tag itself is left as is
                        all_missing_placeholders.add(placeholder)
                        continue
                    if 'bu_2fa_code' in placeholder:
                        # generate a totp code if secret is a 2fa secret
                        replacement_value = pyotp.TOTP(applicable_secrets[placeholder], digits=6).now()
                    else:
                        replacement_value = applicable_secrets[placeholder]
                    value = value.replace(f'<secret>{placeholder}</secret>', replacement_value)
                    replaced_placeholders.add(placeholder)
                return value
            if isinstance(value, dict):
                return {k: recursively_replace_secrets(v) for k, v in value.items()}
            if isinstance(value, list):
                return [recursively_replace_secrets(v) for v in value]
            return value

        processed_params = recursively_replace_secrets(params.model_dump())

        # Log sensitive data usage
        self._log_sensitive_data_usage(replaced_placeholders, current_url)

        # Log a warning if any placeholders are missing (sorted for a stable message)
        if all_missing_placeholders:
            logger.warning(
                f'Missing or empty keys in sensitive_data dictionary: {", ".join(sorted(all_missing_placeholders))}'
            )

        return type(params).model_validate(processed_params)

    # @time_execution_sync('--create_action_model')
    def create_action_model(self, include_actions: list[str] | None = None, page_url: str | None = None) -> type[ActionModel]:
        """Creates a Union of individual action models from registered actions,
        used by LLM APIs that support tool calling & enforce a schema.

        Each action model contains only the specific action being used,
        rather than all actions with most set to None.

        Args:
            include_actions: If given, only actions with these names are considered.
            page_url: If None, only actions without a domain filter are included;
                otherwise only actions whose domain filter matches this URL.

        Returns:
            An empty model when no action qualifies, the single action model when
            exactly one qualifies, otherwise a RootModel over the union of them.
        """
        available_actions: dict[str, RegisteredAction] = {}
        for name, action in self.registry.actions.items():
            if include_actions is not None and name not in include_actions:
                continue

            if page_url is None:
                # If no page_url provided, only include actions with no filters
                if action.domains is None:
                    available_actions[name] = action
                continue

            # Include action if its domain filter matches
            if self.registry._match_domains(action.domains, page_url):
                available_actions[name] = action

        # Create an individual model for each action that contains only one field
        individual_action_models: list[type[BaseModel]] = [
            create_model(
                f'{name.title().replace("_", "")}ActionModel',
                __base__=ActionModel,
                **{
                    name: (
                        action.param_model,
                        Field(description=action.description),
                    )  # type: ignore
                },
            )
            for name, action in available_actions.items()
        ]

        # If no actions available, return empty ActionModel
        if not individual_action_models:
            return create_model('EmptyActionModel', __base__=ActionModel)

        # If only one action, return it directly (no Union needed)
        if len(individual_action_models) == 1:
            return individual_action_models[0]  # type:ignore

        # Create a Union type using RootModel that properly delegates ActionModel methods
        union_type = Union[tuple(individual_action_models)]  # type: ignore : Typing doesn't understand that the length is >= 2 (by design)

        class ActionModelUnion(RootModel[union_type]):  # type: ignore
            def get_index(self) -> int | None:
                """Delegate get_index to the underlying action model"""
                if hasattr(self.root, 'get_index'):
                    return self.root.get_index()  # type: ignore
                return None

            def set_index(self, index: int):
                """Delegate set_index to the underlying action model"""
                if hasattr(self.root, 'set_index'):
                    self.root.set_index(index)  # type: ignore

            def model_dump(self, **kwargs):
                """Delegate model_dump to the underlying action model"""
                if hasattr(self.root, 'model_dump'):
                    return self.root.model_dump(**kwargs)  # type: ignore
                return super().model_dump(**kwargs)

        # Set the name for better debugging
        ActionModelUnion.__name__ = 'ActionModel'
        ActionModelUnion.__qualname__ = 'ActionModel'

        return ActionModelUnion  # type:ignore

    def get_prompt_description(self, page_url: str | None = None) -> str:
        """Get a description of all actions for the prompt.

        If page_url is provided, only include actions that are available for that URL
        based on their domain filters.
        """
        return self.registry.get_prompt_description(page_url=page_url)