dongxx1104's picture
Upload folder using huggingface_hub
db704cb verified
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""The definition of base kernel class.
Init Phase:
1. Define base kernel class.
2. Define abstract methods.
"""
from abc import ABC, abstractmethod
from typing import Any
from ....accelerator.helper import DeviceType, get_current_accelerator
from ....utils.types import HFModel
class BaseKernel(ABC):
    r"""Abstract parent shared by every kernel implementation.

    A concrete kernel subclasses this class, overrides the class attributes
    declared below, and provides its own :meth:`apply`.
    """

    # Identifier of the kernel implementation; any hashable value may be used.
    _kernel_id: Any = ""
    # Accelerator type the kernel targets ("cuda", "npu", "cpu", etc.).
    _device: DeviceType = DeviceType.CPU

    @classmethod
    def get_kernel_id(cls) -> str:
        """Return the value stored in the ``_kernel_id`` class attribute."""
        kernel_id = cls._kernel_id
        return kernel_id

    @classmethod
    def get_device(cls) -> str:
        """Return the value stored in the ``_device`` class attribute (e.g., "cuda", "npu", "cpu")."""
        device = cls._device
        return device

    @classmethod
    def check_deps(cls) -> bool:
        """Report whether this kernel can be used on the current accelerator.

        The base implementation only compares ``_device`` with the ``type`` of the
        accelerator returned by ``get_current_accelerator()``.

        Returns:
            bool: ``True`` when the two device types agree, ``False`` otherwise.

        .. note::
            In explicit mode, if a user specifies an implementation but this check fails,
            it should raise an error instead of silently switching.
            Kernels can override this method to implement custom dependency checks.
        """
        active_type = get_current_accelerator().type
        device_mismatch = cls._device != active_type
        return not device_mismatch

    @classmethod
    @abstractmethod
    def apply(cls, **kwargs) -> HFModel:
        """Apply the kernel optimization to the model.

        The base implementation never returns: it verifies the dependencies and
        then raises, so subclasses must supply the actual patching logic.

        Args:
            **kwargs: Arbitrary keyword arguments, usually containing the model instance and the kernel configuration.

        Returns:
            HFModel: The model with the kernel applied.

        Raises:
            RuntimeError: If ``check_deps()`` reports that the dependencies are not met.
            NotImplementedError: If the dependencies are met but the subclass did not implement the method.

        Example:
            >>> from llamafactory.v1.plugins.model_plugins.kernels.interface import apply_kernel
            >>> model = HFModel(config=config)
            >>> model = apply_kernel(model=model, kernel_id="npu_fused_moe")
        """
        if cls.check_deps():
            # Dependencies are fine, so the only remaining problem is the missing override.
            raise NotImplementedError

        kernel_name = cls.__name__
        raise RuntimeError(f"{kernel_name} is not available but {kernel_name} kernel was called.")