| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
"""Kernel interface definitions.

Initialization phase (runs when this module is imported):
    1. Scan all kernel modules under ``ops`` and import them.
    2. Collect the kernels they register as the default kernels.
    3. Define the kernel plugin that applies them to a model.

"""
| |
|
| | import importlib |
| | from pathlib import Path |
| |
|
| | from ....utils import logging |
| | from ....utils.plugin import BasePlugin |
| | from ....utils.types import HFModel |
| | from .registry import Registry |
| |
|
| |
|
logger = logging.get_logger(__name__)


def scan_all_kernels():
    """Import every kernel module under the ``ops`` directory and return the registered kernels.

    Every ``.py`` file except ``__init__.py`` is imported as a submodule of this package. The
    files are found recursively under the ``ops`` directory that sits next to this file.
    Importing a module runs its :func:`~registry.register_kernel` decorators, which register
    the kernels with :class:`~registry.Registry`.

    Modules are imported in sorted path order. Registration order, and so the order of the
    returned mapping, therefore does not depend on the filesystem's directory listing order.
    A module that fails to import is skipped with a warning instead of aborting the scan.

    Returns:
        dict[str, type[BaseKernel]]: The result of ``Registry.get_registered_kernels()``.
        It is returned even when the ``ops`` directory does not exist, so callers never
        receive ``None``.
    """
    kernels_root = Path(__file__).parent
    ops_path = kernels_root / "ops"

    if ops_path.is_dir():
        # Sorted so import (and therefore registration) order is deterministic across machines.
        for file_path in sorted(ops_path.rglob("*.py")):
            if file_path.name == "__init__.py":
                continue

            # e.g. <root>/ops/foo/bar.py -> "ops.foo.bar", importable relative to this package.
            module_name = ".".join(file_path.relative_to(kernels_root).with_suffix("").parts)
            full_module_name = f"{__package__}.{module_name}"

            try:
                importlib.import_module(full_module_name)
            except Exception as e:
                # Best effort: one kernel module failing to import must not prevent the others from loading.
                logger.warning(f"[Kernel Registry] Failed to import {full_module_name} when loading kernels: {e}")

    return Registry.get_registered_kernels()


# Kernels discovered once, at import time; the helpers below look kernel IDs up in this mapping.
default_kernels = scan_all_kernels()
| |
|
| |
|
def get_default_kernels():
    """Return the IDs of every default kernel.

    Returns:
        list[str]: The kernel IDs, in the iteration order of ``default_kernels``.
    """
    # Iterating a mapping yields its keys, i.e. the kernel IDs.
    kernel_ids = list(default_kernels)
    return kernel_ids
| |
|
| |
|
def apply_kernel(kernel_id: str, **kwargs):
    """Apply a single registered kernel.

    Args:
        kernel_id (str): The ID of the kernel to apply; must be a key of ``default_kernels``.
        **kwargs: Keyword arguments forwarded unchanged to the kernel's ``apply``.
            ``apply_default_kernels`` passes ``model=<model>``.

    Returns:
        Whatever the kernel's ``apply`` call returns. This is presumably the patched
        ``HFModel``; TODO confirm for every registered kernel.

    Raises:
        ValueError: If ``kernel_id`` is not a registered kernel.
    """
    kernel = default_kernels.get(kernel_id)
    if kernel is None:
        raise ValueError(f"Kernel {kernel_id} not found")

    # Propagate the kernel's result instead of silently returning None.
    return kernel.apply(**kwargs)
| |
|
| |
|
class KernelPlugin(BasePlugin):
    """Plugin for managing kernel optimizations.

    The class adds no behavior beyond ``BasePlugin``. Kernel appliers register under it,
    e.g. ``@KernelPlugin("auto").register()``.
    """
| |
|
| |
|
@KernelPlugin("auto").register()
def apply_default_kernels(model: HFModel, include_kernels: str = None) -> HFModel:
    """Apply a selection of the default registered kernels to ``model``.

    Args:
        model (HFModel): The model instance to apply kernels to.
        include_kernels (str | bool, optional): Which kernels to apply.
            ``"auto"`` or ``True`` applies every default kernel.
            ``None``, ``False`` or ``""`` applies none.
            Any other string is a comma-separated list of kernel IDs. Surrounding whitespace
            and empty entries are ignored, and a repeated ID is applied only once.
            Defaults to None.

    Returns:
        HFModel: The same ``model`` object that was passed in.

    Raises:
        ValueError: If a requested kernel ID is not registered. Every ID is checked before
            any kernel is applied, so no kernel has been applied when this is raised.
    """
    if not include_kernels:
        return model

    if include_kernels == "auto" or include_kernels is True:
        use_kernels = list(default_kernels)
    else:
        # Tolerate "a, b" and trailing commas; dict.fromkeys drops duplicates but keeps first-seen order.
        use_kernels = list(dict.fromkeys(name.strip() for name in include_kernels.split(",") if name.strip()))

    # Validate everything first so an unknown ID cannot leave the model partially patched.
    for kernel in use_kernels:
        if kernel not in default_kernels:
            raise ValueError(f"Kernel {kernel} not found")

    for kernel in use_kernels:
        apply_kernel(kernel, model=model)

    return model
| |
|