dongxx1104's picture
Upload folder using huggingface_hub
db704cb verified
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""The definition of kernel registry.
Init Phase:
1. Define kernel registry.
2. Register kernels.
"""
from ....accelerator.helper import get_current_accelerator
from .base import BaseKernel
__all__ = ["Registry", "register_kernel"]
class Registry:
    """Central lookup table for kernel implementations.

    Kernels are kept in a single class-level mapping shaped like
    ``{ "kernel_id": Class }``.
    """

    _kernels: dict[str, type[BaseKernel]] = {}

    @classmethod
    def register(cls, kernel_cls: type[BaseKernel]) -> type[BaseKernel] | None:
        """Register a kernel class; intended to be used as a class decorator.

        The class must derive from :class:`BaseKernel`. Its ID and target device are read through
        ``get_kernel_id()`` and ``get_device()``.

        Note that when used as a decorator, a kernel whose device does not match the current
        accelerator is skipped and ``None`` is returned in place of the class.

        Args:
            kernel_cls (type[BaseKernel]): The kernel class to register.

        Returns:
            type[BaseKernel] | None: ``kernel_cls`` once it has been stored, or ``None`` if its device
            differs from the current accelerator type (nothing is registered in that case).

        Raises:
            TypeError: If ``kernel_cls`` is not a subclass of :class:`BaseKernel`.
            ValueError: If the kernel ID is empty/missing, or if that ID is already registered.
        """
        if not issubclass(kernel_cls, BaseKernel):
            raise TypeError(f"Class {kernel_cls} must inherit from BaseKernel")

        kernel_id = kernel_cls.get_kernel_id()
        required_device = kernel_cls.get_device()
        current_device = get_current_accelerator().type

        # A kernel built for another device is skipped before its ID is validated,
        # so such kernels never raise ValueError here.
        if required_device != current_device:
            return None

        if not kernel_id:
            raise ValueError(f"Kernel ID (_kernel_id) is needed for {kernel_cls} to register")

        registered = cls._kernels
        if kernel_id in registered:
            previous = registered[kernel_id]
            raise ValueError(f"{kernel_id} already registered! The registered kernel is {previous}")

        registered[kernel_id] = kernel_cls
        return kernel_cls

    @classmethod
    def get(cls, kernel_id: str) -> type[BaseKernel] | None:
        """Look up a registered kernel class by its ID.

        Args:
            kernel_id (str): The ID of the kernel to look up.

        Returns:
            type[BaseKernel] | None: The matching kernel class, or ``None`` when the ID is unknown.
        """
        registered = cls._kernels
        return registered.get(kernel_id)

    @classmethod
    def get_registered_kernels(cls) -> dict[str, type[BaseKernel]]:
        """Return the mapping of every registered kernel.

        The registry's own dictionary is returned (not a copy), so changes made to it by the caller
        are visible to the registry.

        Returns:
            dict[str, type[BaseKernel]]: Mapping from kernel ID to kernel class.
        """
        registered = cls._kernels
        return registered
# Export a module-level decorator alias: ``register_kernel`` is the
# ``Registry.register`` classmethod, so ``@register_kernel`` registers a kernel class.
register_kernel = Registry.register