# iLOVE2D's picture
# Upload 2846 files
# 5374a2d verified
from functools import wraps
from typing import Callable, List, Optional
class ModuleRegistry:
    """Lookup table from a name to a class; each name may be registered once.

    Both a duplicate registration and a failed lookup are reported with
    ``ValueError``.
    """

    def __init__(self):
        # Registered name -> class object stored under that name.
        self.module_dict = {}

    def register_module(self, cls_name: str, cls):
        """Store ``cls`` under ``cls_name``.

        Raises:
            ValueError: if ``cls_name`` is already registered.
        """
        already_taken = cls_name in self.module_dict
        if already_taken:
            raise ValueError(f"Found duplicate module: `{cls_name}`!")
        self.module_dict[cls_name] = cls

    def get_module(self, cls_name: str):
        """Return the class registered under ``cls_name``.

        Raises:
            ValueError: if nothing is registered under ``cls_name``
                (note: ``ValueError``, not ``KeyError``).
        """
        try:
            return self.module_dict[cls_name]
        except KeyError:
            raise ValueError(f"module `{cls_name}` not Found!") from None

    def has_module(self, cls_name: str) -> bool:
        """Return True when ``cls_name`` has been registered."""
        found = cls_name in self.module_dict
        return found
# Shared registry instance that the module-level ``register_module`` helper writes to.
MODULE_REGISTRY = ModuleRegistry()


def register_module(cls_name, cls):
    """Record ``cls`` under ``cls_name`` in the shared ``MODULE_REGISTRY``.

    Convenience forwarder to ``MODULE_REGISTRY.register_module``. It returns
    ``None`` and lets that method's ``ValueError`` for a duplicate name
    propagate.
    """
    registry = MODULE_REGISTRY
    registry.register_module(cls_name, cls)
class ModelRegistry:
    """Registry mapping a model name to its model class and its config class.

    ``register`` raises ``ValueError`` for a name that is already taken.
    ``get_model`` / ``get_model_config`` raise ``KeyError`` for unknown names.
    """

    def __init__(self):
        # name -> model class
        self.models = {}
        # name -> config class; always written together with ``self.models``
        self.model_configs = {}

    def register(self, key: str, model_cls, config_cls):
        """Register ``model_cls`` and its ``config_cls`` under ``key``.

        Raises:
            ValueError: if ``key`` is already registered.
        """
        if key in self.models:
            raise ValueError(f"model name '{key}' is already registered!")
        self.models[key] = model_cls
        self.model_configs[key] = config_cls

    def key_error_message(self, key: str) -> str:
        """Build the error message used when ``key`` is not a registered name.

        The message lists the currently registered names. It explains how to
        register a custom model with the ``register_model`` decorator, which
        takes a config class (not a name) and registers the class's
        ``__name__`` plus any ``alias`` names.
        """
        return (
            f"`{key}` is not a registered model name. "
            f"Currently available model names: {self.get_model_names()}. "
            f"If `{key}` is a customized model, you should decorate its class with "
            f"@register_model(<config class>), passing alias=['{key}'] if `{key}` "
            f"is not the class name, to register the model."
        )

    def get_model(self, key: str):
        """Return the model class registered under ``key``.

        Raises:
            KeyError: if ``key`` is not registered.
        """
        model = self.models.get(key, None)
        if model is None:
            raise KeyError(self.key_error_message(key))
        return model

    def get_model_config(self, key: str):
        """Return the config class registered under ``key``.

        Raises:
            KeyError: if ``key`` is not registered.
        """
        config = self.model_configs.get(key, None)
        if config is None:
            raise KeyError(self.key_error_message(key))
        return config

    def get_model_names(self):
        """Return the registered names in registration order."""
        return list(self.models.keys())
# Shared registry populated by the ``register_model`` decorator defined just below.
MODEL_REGISTRY = ModelRegistry()


def register_model(config_cls, alias: Optional[List[str]] = None):
    """Build a class decorator that registers a model class in ``MODEL_REGISTRY``.

    The decorated class is registered under its ``__name__`` and under every
    name in ``alias``, each paired with ``config_cls``.

    Args:
        config_cls: Config class stored alongside the model class.
        alias: Extra names for the model. A single string is treated as one
            name instead of being iterated character by character.

    Returns:
        A decorator that registers the class and returns it unchanged.

    Raises:
        ValueError: at decoration time, if any of the names is already
            registered or is repeated. In that case none of the names are
            registered.
    """
    if isinstance(alias, str):
        alias = [alias]

    def decorator(cls):
        names = [cls.__name__, *(alias or [])]
        # Validate every name up front. Otherwise a clash on an alias would
        # leave the class half-registered under its own name.
        taken = set(MODEL_REGISTRY.get_model_names())
        for name in names:
            if name in taken:
                raise ValueError(f"model name '{name}' is already registered!")
            taken.add(name)
        for name in names:
            MODEL_REGISTRY.register(name, cls, config_cls)
        return cls

    return decorator
class ParseFunctionRegistry:
    """Registry mapping a name to a parse function.

    Duplicate names are rejected with ``ValueError``. Unknown names raise
    ``KeyError`` whose message lists the registered names.
    """

    def __init__(self):
        # Registered name -> callable.
        self.functions = {}

    def register(self, func_name: str, func):
        """Register a function with a given name.

        Args:
            func_name: The name to register the function under
            func (Callable): The function to register

        Raises:
            ValueError: If a function with the same name is already registered
        """
        if func_name in self.functions:
            raise ValueError(f"Function name '{func_name}' is already registered!")
        self.functions[func_name] = func

    def get_function(self, func_name: str) -> Callable:
        """Get a registered function by name.

        Args:
            func_name: The name of the function to retrieve

        Returns:
            Callable: The registered function

        Raises:
            KeyError: If no function with the given name is registered
        """
        if func_name not in self.functions:
            available_funcs = list(self.functions.keys())
            raise KeyError(f"Function '{func_name}' not found! Available functions: {available_funcs}")
        return self.functions[func_name]

    def has_function(self, func_name: str) -> bool:
        """Check if a function name is registered.

        Args:
            func_name: The name to check

        Returns:
            True if the function name is registered, False otherwise
        """
        return func_name in self.functions
# Shared registry filled by the ``register_parse_function`` decorator just below.
PARSE_FUNCTION_REGISTRY = ParseFunctionRegistry()


def register_parse_function(func):
    """Decorator that records ``func`` in ``PARSE_FUNCTION_REGISTRY``.

    The entry is keyed by ``func.__name__``. It holds a pass-through wrapper,
    built with ``functools.wraps``, that forwards every positional and keyword
    argument to ``func``. The decorator returns that same wrapper. A
    ``ValueError`` from the registry (duplicate name) propagates.
    """
    @wraps(func)
    def _passthrough(*call_args, **call_kwargs):
        return func(*call_args, **call_kwargs)

    registry_key = func.__name__
    PARSE_FUNCTION_REGISTRY.register(registry_key, _passthrough)
    return _passthrough
class ActionFunctionRegistry:
    """Registry mapping a name to an action function.

    Duplicate names are rejected with ``ValueError``. Unknown names raise
    ``KeyError`` whose message lists the registered names.
    """

    def __init__(self):
        # Registered name -> callable.
        self.functions = {}

    def register(self, func_name: str, func):
        """Register a function with a given name.

        Args:
            func_name: The name to register the function under
            func (Callable): The function to register

        Raises:
            ValueError: If a function with the same name is already registered
        """
        if func_name in self.functions:
            raise ValueError(f"Function name '{func_name}' is already registered!")
        self.functions[func_name] = func

    def get_function(self, func_name: str) -> Callable:
        """Get a registered function by name.

        Args:
            func_name: The name of the function to retrieve

        Returns:
            Callable: The registered function

        Raises:
            KeyError: If no function with the given name is registered
        """
        if func_name not in self.functions:
            available_funcs = list(self.functions.keys())
            raise KeyError(f"Function '{func_name}' not found! Available functions: {available_funcs}")
        return self.functions[func_name]

    def has_function(self, func_name: str) -> bool:
        """Check if a function name is registered.

        Args:
            func_name: The name to check

        Returns:
            True if the function name is registered, False otherwise
        """
        return func_name in self.functions
# Shared registry filled by the ``register_action_function`` decorator just below.
ACTION_FUNCTION_REGISTRY = ActionFunctionRegistry()


def register_action_function(func):
    """Decorator that records ``func`` in ``ACTION_FUNCTION_REGISTRY`` for ActionAgent serialization.

    The entry is keyed by ``func.__name__``. It holds a pass-through wrapper,
    built with ``functools.wraps``, that forwards every positional and keyword
    argument to ``func``.

    Args:
        func (Callable): The function to register

    Returns:
        Callable: the ``wraps``-decorated wrapper, i.e. the same object that
        was stored in the registry (not ``func`` itself).

    Raises:
        ValueError: if a function with the same ``__name__`` is already registered.
    """
    @wraps(func)
    def _passthrough(*call_args, **call_kwargs):
        return func(*call_args, **call_kwargs)

    registry_key = func.__name__
    ACTION_FUNCTION_REGISTRY.register(registry_key, _passthrough)
    return _passthrough