| |
| |
| |
| |
| @@ -370,7 +370,7 @@ def lazy_load_kernel(kernel_name: str, mapping: dict[str, ModuleType | None] = _ |
| if callable(is_kernel_available) and is_kernel_available(): |
| # Try to import the module "{kernel_name}" from parent package level |
| try: |
| - module = importlib.import_module(f"{kernel_name}") |
| + module = importlib.import_module(f"{new_kernel_name}") |
| mapping[kernel_name] = module |
| return module |
| except Exception: |
| |
| |
| |
| |
| @@ -36,6 +36,7 @@ |
| from ...cache_utils import Cache |
| from ...generation import GenerationMixin |
| from ...integrations import use_kernel_forward_from_hub, use_kernelized_func |
| +from ...integrations.hub_kernels import lazy_load_kernel |
| from ...modeling_attn_mask_utils import AttentionMaskConverter |
| from ...modeling_layers import GradientCheckpointingLayer |
| from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast |
| @@ -44,22 +45,9 @@ |
| from ...processing_utils import Unpack |
| from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging |
| from ...utils.generic import maybe_autocast |
| -from ...utils.import_utils import is_causal_conv1d_available, is_mamba_2_ssm_available |
| from .configuration_bamba import BambaConfig |
| |
| |
| -if is_mamba_2_ssm_available(): |
| - from mamba_ssm.ops.triton.selective_state_update import selective_state_update |
| - from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined, mamba_split_conv1d_scan_combined |
| -else: |
| - selective_state_update = None |
| - |
| -if is_causal_conv1d_available(): |
| - from causal_conv1d import causal_conv1d_fn, causal_conv1d_update |
| -else: |
| - causal_conv1d_update, causal_conv1d_fn = None, None |
| - |
| - |
| logger = logging.get_logger(__name__) |
| |
| |
| @@ -501,9 +489,6 @@ def apply_mask_to_padding_states(hidden_states, attention_mask): |
| return hidden_states |
| |
| |
| -is_fast_path_available = all((selective_state_update, causal_conv1d_fn, causal_conv1d_update)) |
| - |
| - |
| # Adapted from transformers.models.mamba2.modeling_mamba2.Mamba2Mixer |
| class BambaMixer(nn.Module): |
| """ |
| @@ -575,6 +560,20 @@ def __init__(self, config: BambaConfig, layer_idx: int): |
| |
| self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=self.use_bias) |
| |
| + global causal_conv1d_update, causal_conv1d_fn |
| + causal_conv1d = lazy_load_kernel("causal-conv1d") |
| + causal_conv1d_update = getattr(causal_conv1d, "causal_conv1d_update", None) |
| + causal_conv1d_fn = getattr(causal_conv1d, "causal_conv1d_fn", None) |
| + |
| + global selective_state_update, mamba_chunk_scan_combined, mamba_split_conv1d_scan_combined |
| + mamba_ssm = lazy_load_kernel("mamba-ssm") |
| + selective_state_update = getattr(mamba_ssm, "selective_state_update", None) |
| + mamba_chunk_scan_combined = getattr(mamba_ssm, "mamba_chunk_scan_combined", None) |
| + mamba_split_conv1d_scan_combined = getattr(mamba_ssm, "mamba_split_conv1d_scan_combined", None) |
| + |
| + global is_fast_path_available |
| + is_fast_path_available = all((selective_state_update, causal_conv1d_fn, causal_conv1d_update)) |
| + |
| if not is_fast_path_available: |
| logger.warning_once( |
| "The fast path is not available because one of `(selective_state_update, causal_conv1d_fn, causal_conv1d_update)`" |
| |
| |
| |
| |
| @@ -43,6 +43,7 @@ |
| ) |
| |
| from ... import initialization as init |
| +from ...integrations.hub_kernels import lazy_load_kernel |
| from ...modeling_attn_mask_utils import AttentionMaskConverter |
| from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast |
| from ...modeling_utils import PreTrainedModel |
| @@ -52,24 +53,9 @@ |
| can_return_tuple, |
| logging, |
| ) |
| -from ...utils.import_utils import is_causal_conv1d_available, is_mamba_2_ssm_available |
| from .configuration_bamba import BambaConfig |
| |
| |
| -if is_mamba_2_ssm_available(): |
| - from mamba_ssm.ops.triton.selective_state_update import selective_state_update |
| - from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined, mamba_split_conv1d_scan_combined |
| -else: |
| - selective_state_update = None |
| - |
| -if is_causal_conv1d_available(): |
| - from causal_conv1d import causal_conv1d_fn, causal_conv1d_update |
| -else: |
| - causal_conv1d_update, causal_conv1d_fn = None, None |
| - |
| -is_fast_path_available = all((selective_state_update, causal_conv1d_fn, causal_conv1d_update)) |
| - |
| - |
| logger = logging.get_logger(__name__) |
| |
| |
| @@ -276,6 +262,20 @@ def __init__(self, config: BambaConfig, layer_idx: int): |
| |
| self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=self.use_bias) |
| |
| + global causal_conv1d_update, causal_conv1d_fn |
| + causal_conv1d = lazy_load_kernel("causal-conv1d") |
| + causal_conv1d_update = getattr(causal_conv1d, "causal_conv1d_update", None) |
| + causal_conv1d_fn = getattr(causal_conv1d, "causal_conv1d_fn", None) |
| + |
| + global selective_state_update, mamba_chunk_scan_combined, mamba_split_conv1d_scan_combined |
| + mamba_ssm = lazy_load_kernel("mamba-ssm") |
| + selective_state_update = getattr(mamba_ssm, "selective_state_update", None) |
| + mamba_chunk_scan_combined = getattr(mamba_ssm, "mamba_chunk_scan_combined", None) |
| + mamba_split_conv1d_scan_combined = getattr(mamba_ssm, "mamba_split_conv1d_scan_combined", None) |
| + |
| + global is_fast_path_available |
| + is_fast_path_available = all((selective_state_update, causal_conv1d_fn, causal_conv1d_update)) |
| + |
| if not is_fast_path_available: |
| logger.warning_once( |
| "The fast path is not available because one of `(selective_state_update, causal_conv1d_fn, causal_conv1d_update)`" |
| |
| |
| |
| |
| @@ -32,6 +32,7 @@ |
| from ...cache_utils import Cache |
| from ...generation import GenerationMixin |
| from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub, use_kernelized_func |
| +from ...integrations.hub_kernels import lazy_load_kernel |
| from ...masking_utils import create_causal_mask |
| from ...modeling_layers import GradientCheckpointingLayer |
| from ...modeling_outputs import BaseModelOutputWithPast, MoeCausalLMOutputWithPast, MoeModelOutputWithPast |
| @@ -40,22 +41,9 @@ |
| from ...processing_utils import Unpack |
| from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging |
| from ...utils.generic import check_model_inputs, maybe_autocast |
| -from ...utils.import_utils import is_causal_conv1d_available, is_mamba_2_ssm_available |
| from .configuration_granitemoehybrid import GraniteMoeHybridConfig |
| |
| |
| -if is_mamba_2_ssm_available(): |
| - from mamba_ssm.ops.triton.selective_state_update import selective_state_update |
| - from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined, mamba_split_conv1d_scan_combined |
| -else: |
| - selective_state_update = None |
| - |
| -if is_causal_conv1d_available(): |
| - from causal_conv1d import causal_conv1d_fn, causal_conv1d_update |
| -else: |
| - causal_conv1d_update, causal_conv1d_fn = None, None |
| - |
| - |
| logger = logging.get_logger(__name__) |
| |
| |
| @@ -371,9 +359,6 @@ def apply_mask_to_padding_states(hidden_states, attention_mask): |
| return hidden_states |
| |
| |
| -is_fast_path_available = all((selective_state_update, causal_conv1d_fn, causal_conv1d_update)) |
| - |
| - |
| # Adapted from transformers.models.mamba2.modeling_mamba2.Mamba2Mixer |
| class GraniteMoeHybridMambaLayer(nn.Module): |
| """ |
| @@ -445,6 +430,20 @@ def __init__(self, config: GraniteMoeHybridConfig, layer_idx: int): |
| |
| self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=self.use_bias) |
| |
| + global causal_conv1d_update, causal_conv1d_fn |
| + causal_conv1d = lazy_load_kernel("causal-conv1d") |
| + causal_conv1d_update = getattr(causal_conv1d, "causal_conv1d_update", None) |
| + causal_conv1d_fn = getattr(causal_conv1d, "causal_conv1d_fn", None) |
| + |
| + global selective_state_update, mamba_chunk_scan_combined, mamba_split_conv1d_scan_combined |
| + mamba_ssm = lazy_load_kernel("mamba-ssm") |
| + selective_state_update = getattr(mamba_ssm, "selective_state_update", None) |
| + mamba_chunk_scan_combined = getattr(mamba_ssm, "mamba_chunk_scan_combined", None) |
| + mamba_split_conv1d_scan_combined = getattr(mamba_ssm, "mamba_split_conv1d_scan_combined", None) |
| + |
| + global is_fast_path_available |
| + is_fast_path_available = all((selective_state_update, causal_conv1d_fn, causal_conv1d_update)) |
| + |
| if not is_fast_path_available: |
| logger.warning_once( |
| "The fast path is not available because one of `(selective_state_update, causal_conv1d_fn, causal_conv1d_update)`" |
| |
| |
| |
| |
| @@ -33,6 +33,7 @@ |
| from ...activations import ACT2FN |
| from ...generation import GenerationMixin |
| from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub, use_kernelized_func |
| +from ...integrations.hub_kernels import lazy_load_kernel |
| from ...masking_utils import create_causal_mask |
| from ...modeling_layers import GenericForSequenceClassification, GradientCheckpointingLayer |
| from ...modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast |
| @@ -40,22 +41,9 @@ |
| from ...processing_utils import Unpack |
| from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging |
| from ...utils.generic import OutputRecorder, check_model_inputs |
| -from ...utils.import_utils import is_causal_conv1d_available, is_mamba_ssm_available |
| from .configuration_jamba import JambaConfig |
| |
| |
| -if is_mamba_ssm_available(): |
| - from mamba_ssm.ops.selective_scan_interface import mamba_inner_fn, selective_scan_fn |
| - from mamba_ssm.ops.triton.selective_state_update import selective_state_update |
| -else: |
| - selective_state_update, selective_scan_fn, mamba_inner_fn = None, None, None |
| - |
| -if is_causal_conv1d_available(): |
| - from causal_conv1d import causal_conv1d_fn, causal_conv1d_update |
| -else: |
| - causal_conv1d_update, causal_conv1d_fn = None, None |
| - |
| - |
| logger = logging.get_logger(__name__) |
| |
| |
| @@ -306,11 +294,6 @@ def forward( |
| return attn_output, attn_weights |
| |
| |
| -is_fast_path_available = all( |
| - (selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn) |
| -) |
| - |
| - |
| class JambaMambaMixer(nn.Module): |
| """ |
| Compute ∆, A, B, C, and D the state space parameters and compute the `contextualized_states`. |
| @@ -364,6 +347,22 @@ def __init__(self, config: JambaConfig, layer_idx): |
| self.b_layernorm = JambaRMSNorm(self.ssm_state_size, eps=config.rms_norm_eps) |
| self.c_layernorm = JambaRMSNorm(self.ssm_state_size, eps=config.rms_norm_eps) |
| |
| + global causal_conv1d_update, causal_conv1d_fn |
| + causal_conv1d = lazy_load_kernel("causal-conv1d") |
| + causal_conv1d_update = getattr(causal_conv1d, "causal_conv1d_update", None) |
| + causal_conv1d_fn = getattr(causal_conv1d, "causal_conv1d_fn", None) |
| + |
| + global selective_state_update, mamba_inner_fn, selective_scan_fn |
| + mamba_ssm = lazy_load_kernel("mamba-ssm") |
| + selective_state_update = getattr(mamba_ssm, "selective_state_update", None) |
| + mamba_inner_fn = getattr(mamba_ssm, "mamba_inner_fn", None) |
| + selective_scan_fn = getattr(mamba_ssm, "selective_scan_fn", None) |
| + |
| + global is_fast_path_available |
| + is_fast_path_available = all( |
| + (selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn) |
| + ) |
| + |
| if not is_fast_path_available: |
| logger.warning_once( |
| "The fast path is not available because on of `(selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)`" |
| |
| |
| |
| |
| @@ -25,6 +25,7 @@ |
| |
| from ... import initialization as init |
| from ...activations import ACT2FN |
| +from ...integrations.hub_kernels import lazy_load_kernel |
| from ...masking_utils import create_causal_mask |
| from ...modeling_layers import GenericForSequenceClassification, GradientCheckpointingLayer |
| from ...modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast |
| @@ -32,29 +33,12 @@ |
| from ...processing_utils import Unpack |
| from ...utils import TransformersKwargs, auto_docstring, logging |
| from ...utils.generic import OutputRecorder, check_model_inputs |
| -from ...utils.import_utils import is_causal_conv1d_available, is_mamba_ssm_available |
| from ..llama.modeling_llama import LlamaAttention, LlamaRMSNorm, eager_attention_forward |
| from ..mistral.modeling_mistral import MistralMLP |
| from ..mixtral.modeling_mixtral import MixtralExperts, MixtralForCausalLM |
| from .configuration_jamba import JambaConfig |
| |
| |
| -if is_mamba_ssm_available(): |
| - from mamba_ssm.ops.selective_scan_interface import mamba_inner_fn, selective_scan_fn |
| - from mamba_ssm.ops.triton.selective_state_update import selective_state_update |
| -else: |
| - selective_state_update, selective_scan_fn, mamba_inner_fn = None, None, None |
| - |
| -if is_causal_conv1d_available(): |
| - from causal_conv1d import causal_conv1d_fn, causal_conv1d_update |
| -else: |
| - causal_conv1d_update, causal_conv1d_fn = None, None |
| - |
| -is_fast_path_available = all( |
| - (selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn) |
| -) |
| - |
| - |
| logger = logging.get_logger(__name__) |
| |
| |
| @@ -258,6 +242,22 @@ def __init__(self, config: JambaConfig, layer_idx): |
| self.b_layernorm = JambaRMSNorm(self.ssm_state_size, eps=config.rms_norm_eps) |
| self.c_layernorm = JambaRMSNorm(self.ssm_state_size, eps=config.rms_norm_eps) |
| |
| + global causal_conv1d_update, causal_conv1d_fn |
| + causal_conv1d = lazy_load_kernel("causal-conv1d") |
| + causal_conv1d_update = getattr(causal_conv1d, "causal_conv1d_update", None) |
| + causal_conv1d_fn = getattr(causal_conv1d, "causal_conv1d_fn", None) |
| + |
| + global selective_state_update, mamba_inner_fn, selective_scan_fn |
| + mamba_ssm = lazy_load_kernel("mamba-ssm") |
| + selective_state_update = getattr(mamba_ssm, "selective_state_update", None) |
| + mamba_inner_fn = getattr(mamba_ssm, "mamba_inner_fn", None) |
| + selective_scan_fn = getattr(mamba_ssm, "selective_scan_fn", None) |
| + |
| + global is_fast_path_available |
| + is_fast_path_available = all( |
| + (selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn) |
| + ) |
| + |
| if not is_fast_path_available: |
| logger.warning_once( |
| "The fast path is not available because on of `(selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)`" |
| |
| |
| |
| |
| @@ -24,6 +24,7 @@ |
| from ... import initialization as init |
| from ...activations import ACT2FN |
| from ...generation import GenerationMixin |
| +from ...integrations.hub_kernels import lazy_load_kernel |
| from ...modeling_layers import GradientCheckpointingLayer |
| from ...modeling_utils import PreTrainedModel |
| from ...utils import ( |
| @@ -31,35 +32,12 @@ |
| auto_docstring, |
| logging, |
| ) |
| -from ...utils.import_utils import is_causal_conv1d_available, is_mamba_2_ssm_available |
| from .configuration_mamba2 import Mamba2Config |
| |
| |
| logger = logging.get_logger(__name__) |
| |
| |
| -if is_mamba_2_ssm_available(): |
| - from mamba_ssm.ops.triton.selective_state_update import selective_state_update |
| - from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined, mamba_split_conv1d_scan_combined |
| -else: |
| - mamba_chunk_scan_combined, mamba_split_conv1d_scan_combined, selective_state_update = None, None, None |
| - |
| -if is_causal_conv1d_available(): |
| - from causal_conv1d import causal_conv1d_fn, causal_conv1d_update |
| -else: |
| - causal_conv1d_update, causal_conv1d_fn = None, None |
| - |
| -is_fast_path_available = all( |
| - ( |
| - selective_state_update, |
| - mamba_chunk_scan_combined, |
| - mamba_split_conv1d_scan_combined, |
| - causal_conv1d_fn, |
| - causal_conv1d_update, |
| - ) |
| -) |
| - |
| - |
| # Helper methods for segment sum computation |
| |
| |
| @@ -286,6 +264,28 @@ def __init__(self, config: Mamba2Config, layer_idx: int): |
| self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.use_bias) |
| self.use_bias = config.use_bias |
| |
| + global causal_conv1d_update, causal_conv1d_fn |
| + causal_conv1d = lazy_load_kernel("causal-conv1d") |
| + causal_conv1d_update = getattr(causal_conv1d, "causal_conv1d_update", None) |
| + causal_conv1d_fn = getattr(causal_conv1d, "causal_conv1d_fn", None) |
| + |
| + global selective_state_update, mamba_chunk_scan_combined, mamba_split_conv1d_scan_combined |
| + mamba_ssm = lazy_load_kernel("mamba-ssm") |
| + selective_state_update = getattr(mamba_ssm, "selective_state_update", None) |
| + mamba_chunk_scan_combined = getattr(mamba_ssm, "mamba_chunk_scan_combined", None) |
| + mamba_split_conv1d_scan_combined = getattr(mamba_ssm, "mamba_split_conv1d_scan_combined", None) |
| + |
| + global is_fast_path_available |
| + is_fast_path_available = all( |
| + ( |
| + selective_state_update, |
| + mamba_chunk_scan_combined, |
| + mamba_split_conv1d_scan_combined, |
| + causal_conv1d_fn, |
| + causal_conv1d_update, |
| + ) |
| + ) |
| + |
| if not is_fast_path_available: |
| logger.warning_once( |
| "The fast path is not available because one of `(selective_state_update, causal_conv1d_fn, causal_conv1d_update)`" |
| |
| |
| |
| |
| @@ -45,10 +45,7 @@ |
| from ...processing_utils import Unpack |
| from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging |
| from ...utils.generic import OutputRecorder, check_model_inputs, maybe_autocast |
| -from ...utils.import_utils import ( |
| - is_causal_conv1d_available, |
| - is_flash_linear_attention_available, |
| -) |
| +from ...utils.import_utils import is_causal_conv1d_available, is_flash_linear_attention_available |
| from .configuration_qwen3_next import Qwen3NextConfig |
| |
| |
|
|