| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
"""Kernel registry definitions.

Provides :class:`Registry`, a class-level store mapping kernel IDs to kernel
classes, and ``register_kernel``, a decorator alias for
:meth:`Registry.register`.

Init phase:
    1. Define the kernel registry.
    2. Register kernels. Only kernels whose device matches the current
       accelerator are stored; others are skipped.
"""
| |
|
| | from ....accelerator.helper import get_current_accelerator |
| | from .base import BaseKernel |
| |
|
| |
|
# Public API of this module: the registry class and its decorator alias.
__all__ = ["Registry", "register_kernel"]
| |
|
| |
|
class Registry:
    """Registry for managing kernel implementations.

    Kernels are added with :meth:`register` (usually through the module-level
    ``register_kernel`` decorator). Only kernels whose device matches the
    current accelerator are stored, so each kernel ID maps to at most one
    implementation for the running device.

    Storage structure: ``{ "kernel_id": Class }``
    """

    # Class-level store shared by all callers of the classmethods below.
    _kernels: dict[str, type[BaseKernel]] = {}

    @classmethod
    def register(cls, kernel_cls: type[BaseKernel]) -> type[BaseKernel] | None:
        """Decorator to register a kernel class.

        The class must inherit from :class:`BaseKernel` and specify ``_kernel_id`` and ``_device`` attributes.
        Kernels whose device does not match the current accelerator are skipped (not stored).

        Args:
            kernel_cls (type[BaseKernel]): The kernel class to register.

        Returns:
            type[BaseKernel] | None: The registered kernel class if its device type matches the
            current accelerator, else ``None``. When used as a class decorator, a ``None`` return
            rebinds the decorated class name to ``None`` in its defining module.

        Raises:
            TypeError: If ``kernel_cls`` is not a class inheriting from :class:`BaseKernel`.
            ValueError: If the kernel ID is missing, or the ID is already registered for the
                current device.
        """
        # issubclass() raises its own, less helpful TypeError for non-class
        # arguments (functions, instances), so check for a class first.
        if not (isinstance(kernel_cls, type) and issubclass(kernel_cls, BaseKernel)):
            raise TypeError(f"Class {kernel_cls} must inherit from BaseKernel")

        kernel_id = kernel_cls.get_kernel_id()
        # Validate the ID before device filtering so a misconfigured kernel is
        # reported on every machine, not only on its target accelerator.
        if not kernel_id:
            raise ValueError(f"Kernel ID (_kernel_id) is needed for {kernel_cls} to register")

        # Only keep implementations for the accelerator we are running on.
        if kernel_cls.get_device() != get_current_accelerator().type:
            return None

        # Kernels for other devices are never stored, so duplicates are only
        # detected among implementations for the current device.
        if kernel_id in cls._kernels:
            raise ValueError(f"{kernel_id} already registered! The registered kernel is {cls._kernels[kernel_id]}")

        cls._kernels[kernel_id] = kernel_cls
        return kernel_cls

    @classmethod
    def get(cls, kernel_id: str) -> type[BaseKernel] | None:
        """Retrieves a registered kernel implementation by its ID.

        Args:
            kernel_id (str): The ID of the kernel to retrieve.

        Returns:
            type[BaseKernel] | None: The kernel class if found, else ``None``.
        """
        return cls._kernels.get(kernel_id)

    @classmethod
    def get_registered_kernels(cls) -> dict[str, type[BaseKernel]]:
        """Returns a dictionary of all registered kernels.

        Returns:
            dict[str, type[BaseKernel]]: A shallow copy mapping kernel IDs to kernel classes.
            Mutating the returned dict does not affect the registry.
        """
        # Copy so callers cannot add/remove entries behind register()'s checks.
        return dict(cls._kernels)


# Module-level alias so kernels can be registered with ``@register_kernel``.
register_kernel = Registry.register
| |
|