| import importlib |
| from functools import partial |
| from pathlib import Path |
| from typing import Any, Callable, Dict, List |
|
|
# Kinds of component the registry keeps a separate name table for. Each entry
# also backs a pair of dynamic accessors on ``Registry``:
# ``register_<category>`` and ``get_<category>``.
CATEGORIES = [
    "prompt", "llm", "node", "worker",
    "tool", "encoder", "connector", "component",
]
|
|
|
|
| class Registry: |
| """Class for module registration and retrieval.""" |
|
|
| def __init__(self): |
| |
| self.mapping = {key: {} for key in CATEGORIES} |
|
|
| def __getattr__(self, name: str) -> Callable: |
| if name.startswith(("register_", "get_")): |
| prefix, category = name.split("_", 1) |
| if category in CATEGORIES: |
| if prefix == "register": |
| return partial(self.register, category) |
| elif prefix == "get": |
| return partial(self.get, category) |
| raise AttributeError( |
| f"'{self.__class__.__name__}' object has no attribute '{name}'" |
| ) |
|
|
| def _register(self, category: str, name: str = None): |
| """ |
| Registers a module under a specific category. |
| |
| :param category: The category to register the module under. |
| :param name: The name to register the module as. |
| """ |
|
|
| def wrap(module): |
| nonlocal name |
| name = name or module.__name__ |
| if name in self.mapping[category]: |
| raise ValueError( |
| f"Module {name} [{self.mapping[category].get(name)}] already registered in category {category}. Please use a different class name." |
| ) |
| self.mapping.setdefault(category, {})[name] = module |
| return module |
|
|
| return wrap |
|
|
| def _get(self, category: str, name: str): |
| """ |
| Retrieves a module from a specified category. |
| |
| :param category: The category to search in. |
| :param name: The name of the module to retrieve. |
| :raises KeyError: If the module is not found. |
| """ |
| try: |
| return self.mapping[category][name] |
| except KeyError: |
| raise KeyError(f"Module {name} not found in category {category}") |
|
|
| def register(self, category: str, name: str = None): |
| """ |
| Registers a module under a general category. |
| |
| :param category: The category to register the module under. |
| :param name: Optional name to register the module as. |
| """ |
| return self._register(category, name) |
|
|
| def get(self, category: str, name: str): |
| """ |
| Retrieves a module from a general category. |
| |
| :param category: The category to search in. |
| :param name: The name of the module to retrieve. |
| """ |
| return self._get(category, name) |
|
|
| def import_module(self, project_path: List[str] | str = None): |
| """Import modules from default paths and optional project paths. |
| |
| Args: |
| project_path: Optional path or list of paths to import modules from |
| """ |
| |
| root_path = Path(__file__).parents[1] |
| default_path = [ |
| root_path.joinpath("models"), |
| root_path.joinpath("tool_system"), |
| root_path.joinpath("services"), |
| root_path.joinpath("memories"), |
| root_path.joinpath("advanced_components"), |
| root_path.joinpath("clients"), |
| ] |
|
|
| for path in default_path: |
| for module in path.rglob("*.[ps][yo]"): |
| if module.name == "workflow.py": |
| continue |
| module = str(module) |
| if "__init__" in module or "base.py" in module or "entry.py" in module: |
| continue |
| module = "omagent_core" + module.rsplit("omagent_core", 1)[1].rsplit( |
| ".", 1 |
| )[0].replace("/", ".") |
| importlib.import_module(module) |
|
|
| |
| if project_path: |
| if isinstance(project_path, (str, Path)): |
| project_path = [project_path] |
|
|
| for path in project_path: |
| path = Path(path).absolute() |
| project_root = path.parent |
| for module in path.rglob("*.[ps][yo]"): |
| module = str(module) |
| if "__init__" in module: |
| continue |
| module = ( |
| module.replace(str(project_root) + "/", "") |
| .rsplit(".", 1)[0] |
| .replace("/", ".") |
| ) |
| importlib.import_module(module) |
|
|
|
|
| |
# Module-level Registry instance; the demo below registers into it.
registry = Registry()


if __name__ == "__main__":
    # Smoke test: register a throwaway class through the dynamic
    # ``register_node`` shortcut, then fetch it back via ``get_node``.

    @registry.register_node()
    class TestNode:
        # NOTE(review): this is a bare annotation, not an assignment, so the
        # class gets no ``name`` attribute; registration falls back to the
        # class's ``__name__``. Confirm whether ``name = "TestNode"`` was
        # intended.
        name: "TestNode"

    fetched = registry.get_node("TestNode")
    print(fetched)
|
|