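"""Integration of hub kernels via the `kernels` package.

This module maps a set of layer names to pre-built kernel implementations hosted on the
Hugging Face Hub, and resolves `attn_implementation` strings of the form
"org/repo[@revision][:kernel_name]" into attention functions registered with transformers.
"""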
import re
from functools import partial
from typing import Optional, Union

from ..modeling_flash_attention_utils import lazy_import_flash_attention
from .flash_attention import flash_attention_forward


try:
    from kernels import (
        Device,
        LayerRepository,
        Mode,
        get_kernel,
        register_kernel_mapping,
        replace_kernel_forward_from_hub,
        use_kernel_forward_from_hub,
    )

    _kernels_available = True

    _KERNEL_MAPPING: dict[str, dict[Union[Device, str], LayerRepository]] = {
        "MultiScaleDeformableAttention": {
            "cuda": LayerRepository(
                repo_id="kernels-community/deformable-detr",
                layer_name="MultiScaleDeformableAttention",
            )
        },
        "Llama4TextMoe": {
            "cuda": LayerRepository(
                repo_id="kernels-community/moe",
                layer_name="Llama4TextMoe",
            )
        },
        "RMSNorm": {
            "cuda": LayerRepository(
                repo_id="kernels-community/liger_kernels",
                layer_name="LigerRMSNorm",
            ),
            "rocm": {
                Mode.INFERENCE: LayerRepository(
                    repo_id="kernels-community/liger_kernels",
                    layer_name="LigerRMSNorm",
                )
            },
        },
        "MLP": {
            "cuda": LayerRepository(
                repo_id="medmekk/triton-llama-mlp",
                layer_name="TritonLlamaMLP",
            )
        },
        "MegaBlocksMoeMLP": {
            "cuda": {
                Mode.TRAINING: LayerRepository(
                    repo_id="kernels-community/megablocks",
                    layer_name="MegaBlocksMoeMLP",
                ),
                Mode.INFERENCE: LayerRepository(
                    repo_id="kernels-community/megablocks",
                    layer_name="MegaBlocksMoeMLP",
                ),
            },
            "rocm": {
                Mode.INFERENCE: LayerRepository(
                    repo_id="ahadnagy/megablocks",
                    layer_name="MegaBlocksMoeMLP",
                )
            },
        },
        "FastGELU": {
            "cuda": {
                Mode.INFERENCE | Mode.TORCH_COMPILE: LayerRepository(
                    repo_id="kernels-community/activation",
                    layer_name="FastGELU",
                    version=">=0.0.4,<0.1.0",
                )
            }
        },
        "QuickGELU": {
            "cuda": {
                Mode.INFERENCE | Mode.TORCH_COMPILE: LayerRepository(
                    repo_id="kernels-community/activation",
                    layer_name="QuickGELU",
                    version=">=0.0.4,<0.1.0",
                )
            }
        },
        "NewGELU": {
            "cuda": {
                Mode.INFERENCE | Mode.TORCH_COMPILE: LayerRepository(
                    repo_id="kernels-community/activation",
                    layer_name="NewGELU",
                    version=">=0.0.4,<0.1.0",
                )
            }
        },
        "SiLU": {
            "cuda": {
                Mode.INFERENCE | Mode.TORCH_COMPILE: LayerRepository(
                    repo_id="kernels-community/activation", layer_name="Silu", version=">=0.1.0"
                )
            }
        },
        "GeLU": {
            "cuda": {
                Mode.INFERENCE | Mode.TORCH_COMPILE: LayerRepository(
                    repo_id="kernels-community/activation", layer_name="Gelu", version=">=0.1.0"
                )
            }
        },
        "GeluTanh": {
            "cuda": {
                Mode.INFERENCE | Mode.TORCH_COMPILE: LayerRepository(
                    repo_id="kernels-community/activation", layer_name="GeluTanh", version=">=0.1.0"
                )
            }
        },
    }

    register_kernel_mapping(_KERNEL_MAPPING)

except ImportError:
    _kernels_available = False

    # Stubs so that the public names still exist when `kernels` is not installed: the
    # decorator becomes a no-op and the other entry points fail with a clear message.
    def use_kernel_forward_from_hub(*args, **kwargs):
        def decorator(cls):
            return cls

        return decorator

    class LayerRepository:
        def __init__(self, *args, **kwargs):
            raise RuntimeError("LayerRepository requires `kernels` to be installed. Run `pip install kernels`.")

    def replace_kernel_forward_from_hub(*args, **kwargs):
        raise RuntimeError(
            "replace_kernel_forward_from_hub requires `kernels` to be installed. Run `pip install kernels`."
        )

    def register_kernel_mapping(*args, **kwargs):
        raise RuntimeError("register_kernel_mapping requires `kernels` to be installed. Run `pip install kernels`.")
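

# How the mapping above is typically consumed (sketch, not executed here): a module opts
# in by decorating its layer class with the key used in `_KERNEL_MAPPING`, and the hub
# forward is only swapped in when the model is kernelized by the `kernels` package (e.g.
# via its `kernelize` helper) for a matching device/mode. `MyRMSNorm` is a made-up name
# used purely for illustration.
#
#     @use_kernel_forward_from_hub("RMSNorm")
#     class MyRMSNorm(torch.nn.Module):
#         def forward(self, hidden_states):
#             ...  # eager implementation, used when no hub kernel is available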


def is_kernel(attn_implementation: Optional[str]) -> bool:
    """Check whether `attn_implementation` matches a kernel pattern from the hub."""
    return (
        attn_implementation is not None
        and re.search(r"^[^/:]+/[^/:]+(?:@[^/:]+)?(?::[^/:]+)?$", attn_implementation) is not None
    )
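

# A few illustrative inputs (not exhaustive): the pattern accepts "org/repo", optionally
# followed by "@revision" and/or ":kernel_name". "main" and "attn_fn" below are made-up
# placeholders.
#
#     is_kernel("kernels-community/flash-attn")              -> True
#     is_kernel("kernels-community/flash-attn@main:attn_fn") -> True
#     is_kernel("sdpa")                                       -> False  (no "org/repo" part)
#     is_kernel(None)                                         -> False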


def load_and_register_kernel(attn_implementation: str) -> None:
    """Load and register the kernel associated with `attn_implementation`."""
    if not is_kernel(attn_implementation):
        return
    if not _kernels_available:
        raise ImportError(
            "`kernels` is either not installed or uses an incompatible version. "
            "Please install the latest version with `pip install -U kernels`."
        )

    # Imported here to avoid a circular import with `modeling_utils`.
    from ..masking_utils import ALL_MASK_ATTENTION_FUNCTIONS
    from ..modeling_utils import ALL_ATTENTION_FUNCTIONS

    attention_wrapper = None
    actual_attn_name = attn_implementation
    # An optional transformers attention wrapper can be prepended with "|", e.g. "paged|org/repo".
    if "|" in attn_implementation:
        attention_wrapper, actual_attn_name = attn_implementation.split("|")
        attention_wrapper = ALL_ATTENTION_FUNCTIONS.get(attention_wrapper)

    # Split off an explicit kernel name, e.g. "org/repo:kernel_name".
    if ":" in actual_attn_name:
        repo_id, kernel_name = actual_attn_name.split(":")
        kernel_name = kernel_name.strip()
    else:
        repo_id = actual_attn_name
        kernel_name = None
    repo_id = repo_id.strip()

    # Split off an optional revision, e.g. "org/repo@revision".
    repo_id, _, rev = repo_id.partition("@")
    repo_id = repo_id.strip()
    rev = rev.strip() if rev else None

    # Load the kernel from the hub.
    try:
        kernel = get_kernel(repo_id, revision=rev)
    except Exception as e:
        raise ValueError(f"An error occurred while trying to load from '{repo_id}': {e}.")

    # Wrap the kernel so that it can be called as an attention function.
    if hasattr(kernel, "flash_attn_varlen_func"):
        if attention_wrapper is None:
            attention_wrapper = flash_attention_forward
        kernel_function = partial(attention_wrapper, implementation=kernel)
        lazy_import_flash_attention(kernel, force_import=True)
    elif kernel_name is not None:
        kernel_function = getattr(kernel, kernel_name)
    else:
        # Without this branch, `kernel_function` would be unbound below and raise a NameError.
        raise ValueError(
            f"Kernel '{repo_id}' does not expose a flash attention interface and no kernel name was "
            "given; use the 'org/repo:kernel_name' syntax to select a function from the kernel."
        )

    # Register the kernel as a valid attention implementation.
    ALL_ATTENTION_FUNCTIONS.register(attn_implementation, kernel_function)
    ALL_MASK_ATTENTION_FUNCTIONS.register(attn_implementation, ALL_MASK_ATTENTION_FUNCTIONS["flash_attention_2"])
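
# Accepted `attn_implementation` formats (sketch): "org/repo", "org/repo@revision",
# "org/repo:kernel_name", optionally prefixed with a transformers attention wrapper as
# "wrapper|org/repo". For instance, something along these lines is expected to resolve a
# flash-attention kernel from the hub ("main" is a placeholder revision), after which the
# same string can be passed as `attn_implementation` when loading a model:
#
#     load_and_register_kernel("kernels-community/flash-attn@main")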


__all__ = [
    "LayerRepository",
    "use_kernel_forward_from_hub",
    "register_kernel_mapping",
    "replace_kernel_forward_from_hub",
]