from typing import Callable, List, Optional
from functools import wraps


class ModuleRegistry:
    """Name -> class mapping for modules.

    Names are unique: registering the same name twice is an error.
    """

    def __init__(self):
        self.module_dict = {}

    def register_module(self, cls_name: str, cls):
        """Store ``cls`` under ``cls_name``.

        Raises:
            ValueError: If ``cls_name`` is already registered.
        """
        if cls_name in self.module_dict:
            raise ValueError(f"Found duplicate module: `{cls_name}`!")
        self.module_dict[cls_name] = cls

    def get_module(self, cls_name: str):
        """Return the class stored under ``cls_name``.

        Raises:
            ValueError: If ``cls_name`` is not registered.
        """
        if cls_name not in self.module_dict:
            raise ValueError(f"module `{cls_name}` not Found!")
        return self.module_dict[cls_name]

    def has_module(self, cls_name: str) -> bool:
        """Return True if ``cls_name`` is registered."""
        return cls_name in self.module_dict


MODULE_REGISTRY = ModuleRegistry()


def register_module(cls_name, cls):
    """Register ``cls`` under ``cls_name`` in the global ``MODULE_REGISTRY``."""
    MODULE_REGISTRY.register_module(cls_name=cls_name, cls=cls)


class ModelRegistry:
    """Name -> (model class, config class) mapping.

    ``models`` and ``model_configs`` are parallel dicts that always share the
    same keys; both are written together by :meth:`register`.
    """

    def __init__(self):
        self.models = {}
        self.model_configs = {}

    def register(self, key: str, model_cls, config_cls):
        """Store ``model_cls`` and ``config_cls`` under ``key``.

        Raises:
            ValueError: If ``key`` is already registered.
        """
        if key in self.models:
            raise ValueError(f"model name '{key}' is already registered!")
        self.models[key] = model_cls
        self.model_configs[key] = config_cls

    def has_model(self, key: str) -> bool:
        """Return True if ``key`` is a registered model name."""
        return key in self.models

    def key_error_message(self, key: str):
        """Build the message used when ``key`` cannot be resolved."""
        return (
            f"`{key}` is not a registered model name. "
            f"Currently available model names: {self.get_model_names()}. \n"
            f"If `{key}` is a customized model, you should use "
            f"@register_model({key}) to register the model."
        )

    def get_model(self, key: str):
        """Return the model class registered under ``key``.

        Raises:
            KeyError: If ``key`` is unknown (a key whose stored value is
                ``None`` is treated the same way).
        """
        model = self.models.get(key, None)
        if model is None:
            raise KeyError(self.key_error_message(key))
        return model

    def get_model_config(self, key: str):
        """Return the config class registered under ``key``.

        Raises:
            KeyError: If ``key`` is unknown (a key whose stored value is
                ``None`` is treated the same way).
        """
        config = self.model_configs.get(key, None)
        if config is None:
            raise KeyError(self.key_error_message(key))
        return config

    def get_model_names(self):
        """Return all registered names (class names and aliases) as a list."""
        return list(self.models.keys())


MODEL_REGISTRY = ModelRegistry()


def register_model(config_cls, alias: Optional[List[str]] = None):
    """Class decorator that adds a model to the global ``MODEL_REGISTRY``.

    The class is registered under its own ``__name__`` and under every name
    in ``alias``; all of these map to the same ``(cls, config_cls)`` pair.

    Args:
        config_cls: The config class stored alongside the model class.
        alias: Extra names for the model. A single string is accepted and
            treated as one alias (it is not split into characters).

    Returns:
        A decorator that registers the class and returns it unchanged.

    Raises:
        ValueError: (from the decorator) if the class name or any alias is
            already registered, or if the same name appears twice. Nothing
            is registered in that case.
    """

    def decorator(cls):
        if alias is None:
            aliases = []
        elif isinstance(alias, str):
            aliases = [alias]
        else:
            aliases = list(alias)

        names = [cls.__name__, *aliases]

        # Validate every name before touching the registry so that a clash on
        # a later alias cannot leave the class half-registered.
        seen = set()
        for name in names:
            if name in seen or name in MODEL_REGISTRY.models:
                raise ValueError(f"model name '{name}' is already registered!")
            seen.add(name)

        for name in names:
            MODEL_REGISTRY.register(name, cls, config_cls)
        return cls

    return decorator


class _FunctionRegistry:
    """Shared implementation of a name -> function mapping with unique names."""

    def __init__(self):
        self.functions = {}

    def register(self, func_name: str, func):
        """Register a function with a given name.

        Args:
            func_name: The name to register the function under
            func (Callable): The function to register

        Raises:
            ValueError: If a function with the same name is already registered
        """
        if func_name in self.functions:
            raise ValueError(f"Function name '{func_name}' is already registered!")
        self.functions[func_name] = func

    def get_function(self, func_name: str) -> Callable:
        """Get a registered function by name.

        Args:
            func_name: The name of the function to retrieve

        Returns:
            Callable: The registered function

        Raises:
            KeyError: If no function with the given name is registered
        """
        if func_name not in self.functions:
            available_funcs = list(self.functions.keys())
            raise KeyError(f"Function '{func_name}' not found! Available functions: {available_funcs}")
        return self.functions[func_name]

    def has_function(self, func_name: str) -> bool:
        """Check if a function name is registered.

        Args:
            func_name: The name to check

        Returns:
            True if the function name is registered, False otherwise
        """
        return func_name in self.functions


class ParseFunctionRegistry(_FunctionRegistry):
    """Registry of parse functions, keyed by function name."""


PARSE_FUNCTION_REGISTRY = ParseFunctionRegistry()


def register_parse_function(func):
    """Register ``func`` in ``PARSE_FUNCTION_REGISTRY`` under ``func.__name__``.

    Args:
        func (Callable): The function to register

    Returns:
        Callable: The original function (for decorator usage)

    Raises:
        ValueError: If a function with the same name is already registered
    """
    # Register the function itself rather than a pass-through wrapper: a
    # wrapper adds nothing and hides whether ``func`` is a coroutine or
    # generator function from introspection.
    PARSE_FUNCTION_REGISTRY.register(func.__name__, func)
    return func


class ActionFunctionRegistry(_FunctionRegistry):
    """Registry of action functions, keyed by function name."""


ACTION_FUNCTION_REGISTRY = ActionFunctionRegistry()


def register_action_function(func):
    """Register a function for ActionAgent serialization.

    The function is stored in ``ACTION_FUNCTION_REGISTRY`` under
    ``func.__name__``.

    Args:
        func (Callable): The function to register

    Returns:
        Callable: The original function (for decorator usage)

    Raises:
        ValueError: If a function with the same name is already registered
    """
    # Same reasoning as register_parse_function: keep the original callable.
    ACTION_FUNCTION_REGISTRY.register(func.__name__, func)
    return func