| |
| |
| |
| |
| @@ -72,6 +72,7 @@ |
| "register_kernel_mapping", |
| "replace_kernel_forward_from_hub", |
| "use_kernel_forward_from_hub", |
| + "use_kernel_func_from_hub", |
| ], |
| "integration_utils": [ |
| "INTEGRATION_TO_CALLBACK", |
| @@ -212,6 +213,7 @@ |
| register_kernel_mapping, |
| replace_kernel_forward_from_hub, |
| use_kernel_forward_from_hub, |
| + use_kernel_func_from_hub, |
| ) |
| from .integration_utils import ( |
| INTEGRATION_TO_CALLBACK, |
| |
| |
| |
| |
| @@ -34,6 +34,23 @@ |
| register_kernel_mapping, |
| replace_kernel_forward_from_hub, |
| ) |
| + from kernels import ( |
| + use_kernel_forward_from_hub as _kernels_use_kernel_forward_from_hub, |
| + ) |
| + |
| + # Try to import FuncRepository, fallback if not available |
| + try: |
| + from kernels import FuncRepository |
| + except ImportError: |
| + FuncRepository = None |
| + |
| + # Try to import use_kernel_func_from_hub, fallback if not available |
| + try: |
| + from kernels import use_kernel_func_from_hub as _kernels_use_kernel_func_from_hub |
| + |
| + _has_use_kernel_func_from_hub = True |
| + except ImportError: |
| + _has_use_kernel_func_from_hub = False |
| |
| _TRANSFORMERS_USE_HUB_KERNELS = os.environ.get("USE_HUB_KERNELS", "YES").upper() |
| _kernels_available = True |
| @@ -41,8 +58,6 @@ |
| |
| def use_kernel_forward_from_hub(layer_name: str): |
| if _kernels_enabled: |
| - from kernels import use_kernel_forward_from_hub as _kernels_use_kernel_forward_from_hub |
| - |
| return _kernels_use_kernel_forward_from_hub(layer_name) |
| else: |
| logger.warning_once( |
| @@ -50,6 +65,21 @@ def use_kernel_forward_from_hub(layer_name: str): |
| ) |
| return lambda cls: cls |
| |
| + def use_kernel_func_from_hub(func_name: str): |
| + if _kernels_enabled and _has_use_kernel_func_from_hub: |
| + return _kernels_use_kernel_func_from_hub(func_name) |
| + else: |
| + if not _has_use_kernel_func_from_hub: |
| + logger.warning_once( |
| + "use_kernel_func_from_hub is not available in the installed kernels version. " |
| + "Please upgrade kernels to use this feature." |
| + ) |
| + else: |
| + logger.warning_once( |
| + f"kernels hub usage is disabled through the environment USE_HUB_KERNELS={_TRANSFORMERS_USE_HUB_KERNELS}" |
| + ) |
| + return lambda func: func |
| + |
| _KERNEL_MAPPING: dict[str, dict[Device | str, LayerRepository]] = { |
| "MultiScaleDeformableAttention": { |
| "cuda": LayerRepository( |
| @@ -164,6 +194,16 @@ def use_kernel_forward_from_hub(layer_name: str): |
| }, |
| } |
| |
| + # Add function kernel mappings if FuncRepository is available |
| + if FuncRepository is not None: |
| + _KERNEL_MAPPING["rotary_pos_emb"] = { |
| + "xpu": { |
| + Mode.INFERENCE: FuncRepository( |
| + repo_id="kernels-community/rotary", func_name="apply_rotary_transformers" |
| + ) |
| + } |
| + } |
| + |
| def has_key(d, key): |
| return key in d or any(isinstance(v, dict) and has_key(v, key) for v in d.values()) |
| |
| @@ -189,6 +229,12 @@ def decorator(cls): |
| |
| return decorator |
| |
| + def use_kernel_func_from_hub(*args, **kwargs): |
| + def decorator(func): |
| + return func |
| + |
| + return decorator |
| + |
| class LayerRepository: |
| def __init__(self, *args, **kwargs): |
| raise RuntimeError("LayerRepository requires `kernels` to be installed. Run `pip install kernels`.") |
| @@ -201,6 +247,11 @@ def replace_kernel_forward_from_hub(*args, **kwargs): |
| def register_kernel_mapping(*args, **kwargs): |
| raise RuntimeError("register_kernel_mapping requires `kernels` to be installed. Run `pip install kernels`.") |
| |
| + def register_kernel_mapping_transformers(*args, **kwargs): |
| + raise RuntimeError( |
| + "register_kernel_mapping_transformers requires `kernels` to be installed. Run `pip install kernels`." |
| + ) |
| + |
| |
| _HUB_KERNEL_MAPPING: dict[str, dict[str, str]] = { |
| "causal-conv1d": {"repo_id": "kernels-community/causal-conv1d"}, |
| @@ -319,6 +370,7 @@ def lazy_load_kernel(kernel_name: str, mapping: dict[str, ModuleType | None] = _ |
| __all__ = [ |
| "LayerRepository", |
| "use_kernel_forward_from_hub", |
| + "use_kernel_func_from_hub", |
| "register_kernel_mapping", |
| "register_kernel_mapping_transformers", |
| "replace_kernel_forward_from_hub", |
| |
| |
| |
| |
| @@ -28,7 +28,7 @@ |
| from ...activations import ACT2FN |
| from ...cache_utils import Cache, DynamicCache |
| from ...generation import GenerationMixin |
| -from ...integrations import use_kernel_forward_from_hub |
| +from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub |
| from ...masking_utils import create_causal_mask |
| from ...modeling_layers import GenericForTokenClassification, GradientCheckpointingLayer |
| from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast |
| @@ -147,6 +147,7 @@ def rotate_half(x): |
| return torch.cat((-x2, x1), dim=-1) |
| |
| |
| +@use_kernel_func_from_hub("rotary_pos_emb") |
| def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): |
| """Applies Rotary Position Embedding to the query and key tensors. |
| |
| @@ -237,6 +238,7 @@ def __init__(self, config: ApertusConfig, layer_idx: Optional[int] = None): |
| self.o_proj = nn.Linear( |
| config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias |
| ) |
| + self.rotary_fn = apply_rotary_pos_emb |
| self.q_norm = ApertusRMSNorm(self.head_dim, config.rms_norm_eps) |
| self.k_norm = ApertusRMSNorm(self.head_dim, config.rms_norm_eps) |
| |
| |
| |
| |
| |
| @@ -30,7 +30,7 @@ |
| from ...activations import ACT2FN |
| from ...cache_utils import Cache, DynamicCache |
| from ...generation import GenerationMixin |
| -from ...integrations import use_kernel_forward_from_hub |
| +from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub |
| from ...masking_utils import create_causal_mask |
| from ...modeling_layers import ( |
| GenericForQuestionAnswering, |
| @@ -154,6 +154,7 @@ def rotate_half(x): |
| return torch.cat((-x2, x1), dim=-1) |
| |
| |
| +@use_kernel_func_from_hub("rotary_pos_emb") |
| def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): |
| """Applies Rotary Position Embedding to the query and key tensors. |
| |
| @@ -244,6 +245,7 @@ def __init__(self, config: ArceeConfig, layer_idx: int): |
| self.o_proj = nn.Linear( |
| config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias |
| ) |
| + self.rotary_fn = apply_rotary_pos_emb |
| |
| def forward( |
| self, |
| |
| |
| |
| |
| @@ -29,7 +29,7 @@ |
| from ...activations import ACT2FN |
| from ...cache_utils import Cache, DynamicCache |
| from ...generation import GenerationMixin |
| -from ...integrations import use_kernel_forward_from_hub |
| +from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub |
| from ...masking_utils import create_causal_mask |
| from ...modeling_flash_attention_utils import FlashAttentionKwargs |
| from ...modeling_layers import GradientCheckpointingLayer |
| @@ -378,6 +378,7 @@ def rotate_half(x): |
| return torch.cat((-x2, x1), dim=-1) |
| |
| |
| +@use_kernel_func_from_hub("rotary_pos_emb") |
| def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): |
| """Applies Rotary Position Embedding to the query and key tensors. |
| |
| @@ -468,6 +469,7 @@ def __init__(self, config: AriaTextConfig, layer_idx: int): |
| self.o_proj = nn.Linear( |
| config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias |
| ) |
| + self.rotary_fn = apply_rotary_pos_emb |
| |
| def forward( |
| self, |
| |
| |
| |
| |
| @@ -370,6 +370,7 @@ def __init__(self, config: BambaConfig, layer_idx: int): |
| self.o_proj = nn.Linear( |
| config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias |
| ) |
| + self.rotary_fn = apply_rotary_pos_emb |
| |
| def forward( |
| self, |
| |
| |
| |
| |
| @@ -27,7 +27,7 @@ |
| from ...activations import ACT2FN |
| from ...cache_utils import Cache, DynamicCache |
| from ...generation import GenerationMixin |
| -from ...integrations import use_kernel_forward_from_hub |
| +from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub |
| from ...masking_utils import create_causal_mask |
| from ...modeling_flash_attention_utils import FlashAttentionKwargs |
| from ...modeling_layers import GradientCheckpointingLayer |
| @@ -85,6 +85,7 @@ def rotate_half(x): |
| return torch.cat((-x2, x1), dim=-1) |
| |
| |
| +@use_kernel_func_from_hub("rotary_pos_emb") |
| def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): |
| """Applies Rotary Position Embedding to the query and key tensors. |
| |
| @@ -175,6 +176,7 @@ def __init__(self, config: BitNetConfig, layer_idx: int): |
| self.o_proj = nn.Linear( |
| config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias |
| ) |
| + self.rotary_fn = apply_rotary_pos_emb |
| self.attn_sub_norm = BitNetRMSNorm(config.hidden_size, eps=config.rms_norm_eps) |
| |
| def forward( |
| |
| |
| |
| |
| @@ -247,6 +247,7 @@ def __init__(self, config: CohereConfig, layer_idx: Optional[int] = None): |
| self.o_proj = nn.Linear( |
| config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias |
| ) |
| + self.rotary_fn = apply_rotary_pos_emb |
| self.use_qk_norm = config.use_qk_norm |
| if self.use_qk_norm: |
| # When sharding the model using Tensor Parallelism, need to be careful to use n_local_heads |
| |
| |
| |
| |
| @@ -32,7 +32,7 @@ |
| from ...activations import ACT2FN |
| from ...cache_utils import Cache, DynamicCache |
| from ...generation import GenerationMixin |
| -from ...integrations import use_kernel_forward_from_hub |
| +from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub |
| from ...masking_utils import create_causal_mask |
| from ...modeling_layers import GradientCheckpointingLayer |
| from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast |
| @@ -206,6 +206,7 @@ def rotate_half(x): |
| return torch.cat((-x2, x1), dim=-1) |
| |
| |
| +@use_kernel_func_from_hub("rotary_pos_emb") |
| def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): |
| """Applies Rotary Position Embedding to the query and key tensors. |
| |
| @@ -296,6 +297,7 @@ def __init__(self, config: CsmConfig, layer_idx: int): |
| self.o_proj = nn.Linear( |
| config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias |
| ) |
| + self.rotary_fn = apply_rotary_pos_emb |
| |
| def forward( |
| self, |
| |
| |
| |
| |
| @@ -28,7 +28,7 @@ |
| from ...activations import ACT2FN |
| from ...cache_utils import Cache, DynamicCache |
| from ...generation import GenerationMixin |
| -from ...integrations import use_kernel_forward_from_hub |
| +from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub |
| from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask |
| from ...modeling_flash_attention_utils import FlashAttentionKwargs |
| from ...modeling_layers import GradientCheckpointingLayer |
| @@ -113,6 +113,7 @@ def rotate_half(x): |
| return torch.cat((-x2, x1), dim=-1) |
| |
| |
| +@use_kernel_func_from_hub("rotary_pos_emb") |
| def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): |
| """Applies Rotary Position Embedding to the query and key tensors. |
| |
| @@ -195,6 +196,7 @@ def __init__(self, config: CwmConfig, layer_idx: int): |
| self.k_proj = torch.nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False) |
| self.v_proj = torch.nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False) |
| self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False) |
| + self.rotary_fn = apply_rotary_pos_emb |
| self.sliding_window = config.sliding_window if self.layer_type == "sliding_attention" else None |
| |
| def forward( |
| |
| |
| |
| |
| @@ -29,6 +29,7 @@ |
| from ...activations import ACT2FN |
| from ...cache_utils import Cache, DynamicCache |
| from ...generation import GenerationMixin |
| +from ...integrations import use_kernel_func_from_hub |
| from ...masking_utils import create_causal_mask |
| from ...modeling_layers import GradientCheckpointingLayer |
| from ...modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast |
| @@ -112,6 +113,7 @@ def rotate_half(x): |
| return torch.cat((-x2, x1), dim=-1) |
| |
| |
| +@use_kernel_func_from_hub("rotary_pos_emb") |
| def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): |
| """Applies Rotary Position Embedding to the query and key tensors. |
| |
| |
| |
| |
| |
| @@ -16,7 +16,7 @@ |
| from ...activations import ACT2FN |
| from ...cache_utils import Cache, DynamicCache |
| from ...generation import GenerationMixin |
| -from ...integrations import use_kernel_forward_from_hub |
| +from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub |
| from ...masking_utils import create_causal_mask |
| from ...modeling_flash_attention_utils import FlashAttentionKwargs |
| from ...modeling_layers import ( |
| @@ -253,6 +253,7 @@ def rotate_half(x): |
| return torch.cat((-x2, x1), dim=-1) |
| |
| |
| +@use_kernel_func_from_hub("rotary_pos_emb") |
| def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): |
| """Applies Rotary Position Embedding to the query and key tensors. |
| |
| |
| |
| |
| |
| @@ -27,7 +27,7 @@ |
| |
| from ...activations import ACT2FN |
| from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache |
| -from ...integrations import use_kernel_forward_from_hub |
| +from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub |
| from ...masking_utils import create_bidirectional_mask, create_causal_mask |
| from ...modeling_flash_attention_utils import FlashAttentionKwargs |
| from ...modeling_layers import GradientCheckpointingLayer |
| @@ -200,6 +200,7 @@ def rotate_half(x): |
| return torch.cat((-x2, x1), dim=-1) |
| |
| |
| +@use_kernel_func_from_hub("rotary_pos_emb") |
| def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): |
| """Applies Rotary Position Embedding to the query and key tensors. |
| |
| |
| |
| |
| |
| @@ -32,7 +32,7 @@ |
| from ...activations import ACT2FN |
| from ...cache_utils import Cache, DynamicCache, StaticCache |
| from ...generation import GenerationMixin |
| -from ...integrations import use_kernel_forward_from_hub |
| +from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub |
| from ...masking_utils import create_causal_mask |
| from ...modeling_flash_attention_utils import _flash_attention_forward, flash_attn_supports_top_left_mask |
| from ...modeling_layers import ( |
| @@ -141,6 +141,7 @@ def rotate_half(x): |
| return torch.cat((-x2, x1), dim=-1) |
| |
| |
| +@use_kernel_func_from_hub("rotary_pos_emb") |
| def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): |
| """Applies Rotary Position Embedding to the query and key tensors. |
| |
| |
| |
| |
| |
| @@ -33,7 +33,7 @@ |
| from ...activations import ACT2FN |
| from ...cache_utils import Cache, DynamicCache |
| from ...generation import GenerationMixin |
| -from ...integrations import use_kernel_forward_from_hub |
| +from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub |
| from ...integrations.flex_attention import compile_friendly_flex_attention |
| from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask |
| from ...modeling_layers import GenericForSequenceClassification, GradientCheckpointingLayer |
| @@ -143,6 +143,7 @@ def rotate_half(x): |
| return torch.cat((-x2, x1), dim=-1) |
| |
| |
| +@use_kernel_func_from_hub("rotary_pos_emb") |
| def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): |
| """Applies Rotary Position Embedding to the query and key tensors. |
| |
| |
| |
| |
| |
| @@ -29,7 +29,7 @@ |
| from ...activations import ACT2FN |
| from ...cache_utils import Cache, DynamicCache |
| from ...generation import GenerationMixin |
| -from ...integrations import use_kernel_forward_from_hub |
| +from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub |
| from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask |
| from ...modeling_flash_attention_utils import FlashAttentionKwargs |
| from ...modeling_layers import GradientCheckpointingLayer |
| @@ -135,6 +135,7 @@ def rotate_half(x): |
| return torch.cat((-x2, x1), dim=-1) |
| |
| |
| +@use_kernel_func_from_hub("rotary_pos_emb") |
| def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): |
| """Applies Rotary Position Embedding to the query and key tensors. |
| |
| @@ -226,6 +227,7 @@ def __init__(self, config: Dots1Config, layer_idx: int): |
| self.o_proj = nn.Linear( |
| config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias |
| ) |
| + self.rotary_fn = apply_rotary_pos_emb |
| self.q_norm = Dots1RMSNorm(self.head_dim, eps=config.rms_norm_eps) # unlike olmo, only on the head dim! |
| self.k_norm = Dots1RMSNorm(self.head_dim, eps=config.rms_norm_eps) # thus post q_norm does not need reshape |
| self.sliding_window = config.sliding_window if self.layer_type == "sliding_attention" else None |
| |
| |
| |
| |
| @@ -33,7 +33,7 @@ |
| from ...activations import ACT2FN |
| from ...cache_utils import Cache, DynamicCache |
| from ...generation import GenerationMixin |
| -from ...integrations import use_kernel_forward_from_hub |
| +from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub |
| from ...masking_utils import create_causal_mask |
| from ...modeling_layers import GradientCheckpointingLayer |
| from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast |
| @@ -52,6 +52,7 @@ def rotate_half(x): |
| return torch.cat((-x2, x1), dim=-1) |
| |
| |
| +@use_kernel_func_from_hub("rotary_pos_emb") |
| def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): |
| """Applies Rotary Position Embedding to the query and key tensors. |
| |
| @@ -142,6 +143,7 @@ def __init__(self, config: Emu3Config, layer_idx: int): |
| self.o_proj = nn.Linear( |
| config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias |
| ) |
| + self.rotary_fn = apply_rotary_pos_emb |
| |
| def forward( |
| self, |
| |
| |
| |
| |
| @@ -221,6 +221,7 @@ def __init__(self, config: Ernie4_5Config, layer_idx: int): |
| self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.use_bias) |
| self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.use_bias) |
| self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.use_bias) |
| + self.rotary_fn = apply_rotary_pos_emb |
| |
| def forward( |
| self, |
| |
| |
| |
| |
| @@ -244,6 +244,7 @@ def __init__(self, config: Ernie4_5_MoeConfig, layer_idx: int): |
| self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.use_bias) |
| self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.use_bias) |
| self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.use_bias) |
| + self.rotary_fn = apply_rotary_pos_emb |
| |
| def forward( |
| self, |
| |
| |
| |
| |
| @@ -31,7 +31,7 @@ |
| from ...activations import ACT2FN |
| from ...cache_utils import Cache, DynamicCache |
| from ...generation import GenerationMixin |
| -from ...integrations import use_kernel_forward_from_hub |
| +from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub |
| from ...masking_utils import create_bidirectional_mask, create_causal_mask |
| from ...modeling_layers import GradientCheckpointingLayer |
| from ...modeling_outputs import ( |
| @@ -1051,6 +1051,7 @@ def rotate_half(x): |
| return torch.cat((-x2, x1), dim=-1) |
| |
| |
| +@use_kernel_func_from_hub("rotary_pos_emb") |
| def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): |
| """Applies Rotary Position Embedding to the query and key tensors. |
| |
| @@ -1115,6 +1116,7 @@ def __init__(self, config: EvollaConfig, layer_idx: int): |
| self.o_proj = nn.Linear( |
| config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias |
| ) |
| + self.rotary_fn = apply_rotary_pos_emb |
| |
| def forward( |
| self, |
| |
| |
| |
| |
| @@ -31,7 +31,7 @@ |
| from ...activations import ACT2FN |
| from ...cache_utils import Cache, DynamicCache |
| from ...generation import GenerationMixin |
| -from ...integrations import use_kernel_forward_from_hub |
| +from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub |
| from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask |
| from ...modeling_layers import ( |
| GenericForQuestionAnswering, |
| @@ -140,6 +140,7 @@ def rotate_half(x): |
| return torch.cat((-x2, x1), dim=-1) |
| |
| |
| +@use_kernel_func_from_hub("rotary_pos_emb") |
| def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): |
| """Applies Rotary Position Embedding to the query and key tensors. |
| |
| |
| |
| |
| |
| @@ -36,7 +36,7 @@ |
| from ... import initialization as init |
| from ...cache_utils import Cache |
| from ...generation import GenerationMixin |
| -from ...integrations import use_kernel_forward_from_hub |
| +from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub |
| from ...modeling_attn_mask_utils import AttentionMaskConverter |
| from ...modeling_flash_attention_utils import FlashAttentionKwargs |
| from ...modeling_layers import GradientCheckpointingLayer |
| @@ -295,6 +295,7 @@ def rotate_half(x): |
| return torch.cat((-x2, x1), dim=-1) |
| |
| |
| +@use_kernel_func_from_hub("rotary_pos_emb") |
| def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): |
| """Applies Rotary Position Embedding to the query and key tensors. |
| |
| @@ -385,6 +386,7 @@ def __init__(self, config: FalconH1Config, layer_idx: int): |
| self.o_proj = nn.Linear( |
| config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias |
| ) |
| + self.rotary_fn = apply_rotary_pos_emb |
| self.key_multiplier = config.key_multiplier |
| |
| def forward( |
| |
| |
| |
| |
| @@ -241,6 +241,7 @@ def __init__(self, config: FlexOlmoConfig, layer_idx: Optional[int] = None): |
| self.o_proj = nn.Linear( |
| config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias |
| ) |
| + self.rotary_fn = apply_rotary_pos_emb |
| self.q_norm = FlexOlmoRMSNorm(config.num_attention_heads * self.head_dim, config.rms_norm_eps) |
| self.k_norm = FlexOlmoRMSNorm(config.num_key_value_heads * self.head_dim, config.rms_norm_eps) |
| |
| |
| |
| |
| |
| @@ -29,6 +29,7 @@ |
| from ...activations import ACT2FN |
| from ...cache_utils import Cache, DynamicCache |
| from ...generation import GenerationMixin |
| +from ...integrations import use_kernel_func_from_hub |
| from ...masking_utils import create_causal_mask |
| from ...modeling_layers import ( |
| GenericForSequenceClassification, |
| @@ -152,6 +153,7 @@ def rotate_half(x): |
| return torch.cat((-x2, x1), dim=-1) |
| |
| |
| +@use_kernel_func_from_hub("rotary_pos_emb") |
| def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): |
| """Applies Rotary Position Embedding to the query and key tensors. |
| |
| @@ -242,6 +244,7 @@ def __init__(self, config: GemmaConfig, layer_idx: int): |
| self.o_proj = nn.Linear( |
| config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias |
| ) |
| + self.rotary_fn = apply_rotary_pos_emb |
| |
| def forward( |
| self, |
| |
| |
| |
| |
| @@ -29,6 +29,7 @@ |
| from ...activations import ACT2FN |
| from ...cache_utils import Cache, DynamicCache |
| from ...generation import GenerationMixin |
| +from ...integrations import use_kernel_func_from_hub |
| from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask |
| from ...modeling_flash_attention_utils import FlashAttentionKwargs |
| from ...modeling_layers import ( |
| @@ -156,6 +157,7 @@ def rotate_half(x): |
| return torch.cat((-x2, x1), dim=-1) |
| |
| |
| +@use_kernel_func_from_hub("rotary_pos_emb") |
| def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): |
| """Applies Rotary Position Embedding to the query and key tensors. |
| |
| @@ -256,6 +258,7 @@ def __init__(self, config: Gemma2Config, layer_idx: int): |
| self.o_proj = nn.Linear( |
| config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias |
| ) |
| + self.rotary_fn = apply_rotary_pos_emb |
| self.attn_logit_softcapping = self.config.attn_logit_softcapping |
| self.sliding_window = config.sliding_window if self.layer_type == "sliding_attention" else None |
| |
| |
| |
| |
| |
| @@ -31,6 +31,7 @@ |
| from ...cache_utils import Cache, DynamicCache |
| from ...configuration_utils import PreTrainedConfig |
| from ...generation import GenerationMixin |
| +from ...integrations import use_kernel_func_from_hub |
| from ...masking_utils import create_causal_mask, create_masks_for_generate, create_sliding_window_causal_mask |
| from ...modeling_flash_attention_utils import FlashAttentionKwargs |
| from ...modeling_layers import GenericForSequenceClassification, GradientCheckpointingLayer |
| @@ -230,6 +231,7 @@ def rotate_half(x): |
| return torch.cat((-x2, x1), dim=-1) |
| |
| |
| +@use_kernel_func_from_hub("rotary_pos_emb") |
| def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): |
| """Applies Rotary Position Embedding to the query and key tensors. |
| |
| @@ -330,6 +332,7 @@ def __init__(self, config: Gemma3TextConfig, layer_idx: int): |
| self.o_proj = nn.Linear( |
| config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias |
| ) |
| + self.rotary_fn = apply_rotary_pos_emb |
| self.attn_logit_softcapping = self.config.attn_logit_softcapping |
| self.sliding_window = config.sliding_window if self.layer_type == "sliding_attention" else None |
| self.is_sliding = self.layer_type == "sliding_attention" |
| |
| |
| |
| |
| @@ -1254,6 +1254,7 @@ def __init__(self, config: Gemma3nTextConfig, layer_idx: int): |
| self.o_proj = nn.Linear( |
| config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias |
| ) |
| + self.rotary_fn = apply_rotary_pos_emb |
| self.sliding_window = config.sliding_window if self.layer_type == "sliding_attention" else None |
| self.is_sliding = self.layer_type == "sliding_attention" |
| |
| @@ -1475,6 +1476,7 @@ def __init__(self, config: Gemma3nConfig, layer_idx: int): |
| self.o_proj = nn.Linear( |
| config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias |
| ) |
| + self.rotary_fn = apply_rotary_pos_emb |
| self.attn_logit_softcapping = self.config.attn_logit_softcapping |
| self.sliding_window = config.sliding_window if self.layer_type == "sliding_attention" else None |
| |
| |
| |
| |
| |
| @@ -239,6 +239,7 @@ def __init__(self, config: GlmConfig, layer_idx: Optional[int] = None): |
| config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias |
| ) |
| self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False) |
| + self.rotary_fn = apply_rotary_pos_emb |
| |
| def forward( |
| self, |
| |
| |
| |
| |
| @@ -221,6 +221,7 @@ def __init__(self, config: Glm4Config, layer_idx: Optional[int] = None): |
| config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias |
| ) |
| self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False) |
| + self.rotary_fn = apply_rotary_pos_emb |
| |
| def forward( |
| self, |
| |
| |
| |
| |
| @@ -204,6 +204,48 @@ def rotate_half(x): |
| return torch.cat((-x2, x1), dim=-1) |
| |
| |
| +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): |
| + """Applies Rotary Position Embedding to the query and key tensors. |
| + |
| + Args: |
| + q (`torch.Tensor`): The query tensor. |
| + k (`torch.Tensor`): The key tensor. |
| + cos (`torch.Tensor`): The cosine part of the rotary embedding. |
| + sin (`torch.Tensor`): The sine part of the rotary embedding. |
| + position_ids (`torch.Tensor`, *optional*): |
| + Deprecated and unused. |
| + unsqueeze_dim (`int`, *optional*, defaults to 1): |
| + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and |
| + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note |
| + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and |
| + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes |
| + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have |
| + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. |
| + Returns: |
| + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. |
| + """ |
| + cos = cos.unsqueeze(unsqueeze_dim) |
| + sin = sin.unsqueeze(unsqueeze_dim) |
| + |
| + # Interleave them instead of usual shape |
| + cos = cos[..., : cos.shape[-1] // 2].repeat_interleave(2, dim=-1) |
| + sin = sin[..., : sin.shape[-1] // 2].repeat_interleave(2, dim=-1) |
| + |
| + # Keep half or full tensor for later concatenation |
| + rotary_dim = cos.shape[-1] |
| + q_rot, q_pass = q[..., :rotary_dim], q[..., rotary_dim:] |
| + k_rot, k_pass = k[..., :rotary_dim], k[..., rotary_dim:] |
| + |
| + # Apply rotary embeddings on the first half or full tensor |
| + q_embed = (q_rot * cos) + (rotate_half(q_rot) * sin) |
| + k_embed = (k_rot * cos) + (rotate_half(k_rot) * sin) |
| + |
| + # Concatenate back to full shape |
| + q_embed = torch.cat([q_embed, q_pass], dim=-1) |
| + k_embed = torch.cat([k_embed, k_pass], dim=-1) |
| + return q_embed, k_embed |
| + |
| + |
| def apply_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section, unsqueeze_dim=1): |
| """Applies Rotary Position Embedding with Multimodal Sections to the query and key tensors (https://qwenlm.github.io/blog/qwen2-vl/). |
| |
| @@ -280,6 +322,7 @@ def __init__(self, config: Glm4vMoeTextConfig, layer_idx: Optional[int] = None): |
| config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias |
| ) |
| self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False) |
| + self.rotary_fn = apply_rotary_pos_emb |
| self.rope_parameters = config.rope_parameters |
| |
| def forward( |
| |
| |
| |
| |
| @@ -325,6 +325,7 @@ def __init__(self, config: GptOssConfig, layer_idx: int): |
| self.o_proj = nn.Linear( |
| config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias |
| ) |
| + self.rotary_fn = apply_rotary_pos_emb |
| self.sliding_window = config.sliding_window if self.layer_type == "sliding_attention" else None |
| self.sinks = nn.Parameter(torch.empty(config.num_attention_heads)) |
| |
| |
| |
| |
| |
| @@ -28,7 +28,7 @@ |
| from ...activations import ACT2FN |
| from ...cache_utils import Cache, DynamicCache |
| from ...generation import GenerationMixin |
| -from ...integrations import use_kernel_forward_from_hub |
| +from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub |
| from ...masking_utils import create_causal_mask |
| from ...modeling_layers import GradientCheckpointingLayer |
| from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast |
| @@ -50,6 +50,7 @@ def rotate_half(x): |
| return torch.cat((-x2, x1), dim=-1) |
| |
| |
| +@use_kernel_func_from_hub("rotary_pos_emb") |
| def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): |
| """Applies Rotary Position Embedding to the query and key tensors. |
| |
| @@ -140,6 +141,7 @@ def __init__(self, config: GraniteConfig, layer_idx: Optional[int] = None): |
| self.o_proj = nn.Linear( |
| config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias |
| ) |
| + self.rotary_fn = apply_rotary_pos_emb |
| |
| def forward( |
| self, |
| |
| |
| |
| |
| @@ -30,7 +30,7 @@ |
| from ...activations import ACT2FN |
| from ...cache_utils import Cache, DynamicCache |
| from ...generation import GenerationMixin |
| -from ...integrations import use_kernel_forward_from_hub |
| +from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub |
| from ...masking_utils import create_causal_mask |
| from ...modeling_layers import GradientCheckpointingLayer |
| from ...modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast |
| @@ -272,6 +272,7 @@ def rotate_half(x): |
| return torch.cat((-x2, x1), dim=-1) |
| |
| |
| +@use_kernel_func_from_hub("rotary_pos_emb") |
| def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): |
| """Applies Rotary Position Embedding to the query and key tensors. |
| |
| @@ -362,6 +363,7 @@ def __init__(self, config: GraniteMoeConfig, layer_idx: int): |
| self.o_proj = nn.Linear( |
| config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias |
| ) |
| + self.rotary_fn = apply_rotary_pos_emb |
| |
| def forward( |
| self, |
| |
| |
| |
| |
| @@ -31,7 +31,7 @@ |
| from ... import initialization as init |
| from ...cache_utils import Cache |
| from ...generation import GenerationMixin |
| -from ...integrations import use_kernel_forward_from_hub |
| +from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub |
| from ...masking_utils import create_causal_mask |
| from ...modeling_layers import GradientCheckpointingLayer |
| from ...modeling_outputs import BaseModelOutputWithPast, MoeCausalLMOutputWithPast, MoeModelOutputWithPast |
| @@ -59,6 +59,41 @@ |
| logger = logging.get_logger(__name__) |
| |
| |
| +def rotate_half(x): |
| + """Rotates half the hidden dims of the input.""" |
| + x1 = x[..., : x.shape[-1] // 2] |
| + x2 = x[..., x.shape[-1] // 2 :] |
| + return torch.cat((-x2, x1), dim=-1) |
| + |
| + |
| +@use_kernel_func_from_hub("rotary_pos_emb") |
| +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): |
| + """Applies Rotary Position Embedding to the query and key tensors. |
| + |
| + Args: |
| + q (`torch.Tensor`): The query tensor. |
| + k (`torch.Tensor`): The key tensor. |
| + cos (`torch.Tensor`): The cosine part of the rotary embedding. |
| + sin (`torch.Tensor`): The sine part of the rotary embedding. |
| + position_ids (`torch.Tensor`, *optional*): |
| + Deprecated and unused. |
| + unsqueeze_dim (`int`, *optional*, defaults to 1): |
| + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and |
| + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note |
| + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and |
| + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes |
| + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have |
| + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. |
| + Returns: |
| + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. |
| + """ |
| + cos = cos.unsqueeze(unsqueeze_dim) |
| + sin = sin.unsqueeze(unsqueeze_dim) |
| + q_embed = (q * cos) + (rotate_half(q) * sin) |
| + k_embed = (k * cos) + (rotate_half(k) * sin) |
| + return q_embed, k_embed |
| + |
| + |
| def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: |
| """ |
| This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, |
| @@ -122,6 +157,7 @@ def __init__(self, config: GraniteMoeHybridConfig, layer_idx: int): |
| self.o_proj = nn.Linear( |
| config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias |
| ) |
| + self.rotary_fn = apply_rotary_pos_emb |
| |
| def forward( |
| self, |
| |
| |
| |
| |
| @@ -30,7 +30,7 @@ |
| from ...activations import ACT2FN |
| from ...cache_utils import Cache, DynamicCache |
| from ...generation import GenerationMixin |
| -from ...integrations import use_kernel_forward_from_hub |
| +from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub |
| from ...masking_utils import create_causal_mask |
| from ...modeling_layers import GradientCheckpointingLayer |
| from ...modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast |
| @@ -262,6 +262,7 @@ def rotate_half(x): |
| return torch.cat((-x2, x1), dim=-1) |
| |
| |
| +@use_kernel_func_from_hub("rotary_pos_emb") |
| def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): |
| """Applies Rotary Position Embedding to the query and key tensors. |
| |
| @@ -352,6 +353,7 @@ def __init__(self, config: GraniteMoeSharedConfig, layer_idx: int): |
| self.o_proj = nn.Linear( |
| config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias |
| ) |
| + self.rotary_fn = apply_rotary_pos_emb |
| |
| def forward( |
| self, |
| |
| |
| |
| |
| @@ -243,6 +243,7 @@ def __init__(self, config: HeliumConfig, layer_idx: Optional[int] = None): |
| config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias |
| ) |
| self.o_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=False) |
| + self.rotary_fn = apply_rotary_pos_emb |
| |
| def forward( |
| self, |
| |
| |
| |
| |
| @@ -30,7 +30,7 @@ |
| from ...activations import ACT2FN |
| from ...cache_utils import DynamicCache |
| from ...generation import GenerationMixin |
| -from ...integrations import use_kernel_forward_from_hub |
| +from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub |
| from ...masking_utils import create_causal_mask |
| from ...modeling_layers import GenericForSequenceClassification, GradientCheckpointingLayer |
| from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast |
| @@ -87,6 +87,7 @@ def rotate_half(x): |
| return torch.cat((-x2, x1), dim=-1) |
| |
| |
| +@use_kernel_func_from_hub("rotary_pos_emb") |
| def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): |
| """Applies Rotary Position Embedding to the query and key tensors. |
| |
| @@ -177,6 +178,7 @@ def __init__(self, config: HunYuanDenseV1Config, layer_idx: int): |
| self.o_proj = nn.Linear( |
| config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias |
| ) |
| + self.rotary_fn = apply_rotary_pos_emb |
| self.query_layernorm = HunYuanDenseV1RMSNorm(self.head_dim, eps=config.rms_norm_eps) |
| self.key_layernorm = HunYuanDenseV1RMSNorm(self.head_dim, eps=config.rms_norm_eps) |
| |
| |
| |
| |
| |
| @@ -30,7 +30,7 @@ |
| from ...activations import ACT2FN |
| from ...cache_utils import Cache, DynamicCache |
| from ...generation import GenerationMixin |
| -from ...integrations import use_kernel_forward_from_hub |
| +from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub |
| from ...masking_utils import create_causal_mask |
| from ...modeling_layers import GenericForSequenceClassification, GradientCheckpointingLayer |
| from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast |
| @@ -86,6 +86,7 @@ def rotate_half(x): |
| return torch.cat((-x2, x1), dim=-1) |
| |
| |
| +@use_kernel_func_from_hub("rotary_pos_emb") |
| def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): |
| """Applies Rotary Position Embedding to the query and key tensors. |
| |
| @@ -176,6 +177,7 @@ def __init__(self, config: HunYuanMoEV1Config, layer_idx: int): |
| self.o_proj = nn.Linear( |
| config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias |
| ) |
| + self.rotary_fn = apply_rotary_pos_emb |
| self.query_layernorm = HunYuanMoEV1RMSNorm(self.head_dim, eps=config.rms_norm_eps) |
| self.key_layernorm = HunYuanMoEV1RMSNorm(self.head_dim, eps=config.rms_norm_eps) |
| |
| |
| |
| |
| |
| @@ -32,7 +32,7 @@ |
| from ... import initialization as init |
| from ...activations import ACT2FN |
| from ...generation import GenerationMixin |
| -from ...integrations import use_kernel_forward_from_hub |
| +from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub |
| from ...masking_utils import create_causal_mask |
| from ...modeling_layers import GenericForSequenceClassification, GradientCheckpointingLayer |
| from ...modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast |
| @@ -175,6 +175,41 @@ def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: |
| return self.key_cache[layer_idx].shape[-2] |
| |
| |
| +def rotate_half(x): |
| + """Rotates half the hidden dims of the input.""" |
| + x1 = x[..., : x.shape[-1] // 2] |
| + x2 = x[..., x.shape[-1] // 2 :] |
| + return torch.cat((-x2, x1), dim=-1) |
| + |
| + |
| +@use_kernel_func_from_hub("rotary_pos_emb") |
| +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): |
| + """Applies Rotary Position Embedding to the query and key tensors. |
| + |
| + Args: |
| + q (`torch.Tensor`): The query tensor. |
| + k (`torch.Tensor`): The key tensor. |
| + cos (`torch.Tensor`): The cosine part of the rotary embedding. |
| + sin (`torch.Tensor`): The sine part of the rotary embedding. |
| + position_ids (`torch.Tensor`, *optional*): |
| + Deprecated and unused. |
| + unsqueeze_dim (`int`, *optional*, defaults to 1): |
| + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and |
| + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note |
| + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and |
| + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes |
| + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have |
| + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. |
| + Returns: |
| + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. |
| + """ |
| + cos = cos.unsqueeze(unsqueeze_dim) |
| + sin = sin.unsqueeze(unsqueeze_dim) |
| + q_embed = (q * cos) + (rotate_half(q) * sin) |
| + k_embed = (k * cos) + (rotate_half(k) * sin) |
| + return q_embed, k_embed |
| + |
| + |
| def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: |
| """ |
| This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, |
| @@ -229,6 +264,7 @@ def __init__(self, config: JambaConfig, layer_idx: int): |
| self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False) |
| self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False) |
| self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False) |
| + self.rotary_fn = apply_rotary_pos_emb |
| |
| def forward( |
| self, |
| |
| |
| |
| |
| @@ -30,7 +30,7 @@ |
| from ...activations import ACT2FN |
| from ...cache_utils import Cache, DynamicCache |
| from ...generation import GenerationMixin |
| -from ...integrations import use_kernel_forward_from_hub |
| +from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub |
| from ...masking_utils import create_causal_mask |
| from ...modeling_layers import GenericForSequenceClassification, GradientCheckpointingLayer |
| from ...modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast |
| @@ -366,6 +366,7 @@ def rotate_half(x): |
| return torch.cat((-x2, x1), dim=-1) |
| |
| |
| +@use_kernel_func_from_hub("rotary_pos_emb") |
| def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): |
| """Applies Rotary Position Embedding to the query and key tensors. |
| |
| |
| |
| |
| |
| @@ -26,7 +26,7 @@ |
| |
| from ...cache_utils import Cache |
| from ...generation import GenerationMixin |
| -from ...integrations import use_kernel_forward_from_hub |
| +from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub |
| from ...masking_utils import create_causal_mask |
| from ...modeling_layers import GradientCheckpointingLayer |
| from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast |
| @@ -292,6 +292,7 @@ def rotate_half(x): |
| return torch.cat((-x2, x1), dim=-1) |
| |
| |
| +@use_kernel_func_from_hub("rotary_pos_emb") |
| def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): |
| """Applies Rotary Position Embedding to the query and key tensors. |
| |
| @@ -371,6 +372,7 @@ def __init__(self, config: Lfm2Config, layer_idx: int): |
| self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=False) |
| self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False) |
| self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False) |
| + self.rotary_fn = apply_rotary_pos_emb |
| self.out_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False) |
| self.q_layernorm = Lfm2RMSNorm(self.head_dim, eps=config.norm_eps) |
| self.k_layernorm = Lfm2RMSNorm(self.head_dim, eps=config.norm_eps) |
| |
| |
| |
| |
| @@ -28,7 +28,7 @@ |
| from ...activations import ACT2FN |
| from ...cache_utils import Cache |
| from ...generation import GenerationMixin |
| -from ...integrations import use_kernel_forward_from_hub |
| +from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub |
| from ...masking_utils import create_causal_mask |
| from ...modeling_layers import GradientCheckpointingLayer |
| from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, MoeModelOutputWithPast |
| @@ -363,6 +363,7 @@ def rotate_half(x): |
| return torch.cat((-x2, x1), dim=-1) |
| |
| |
| +@use_kernel_func_from_hub("rotary_pos_emb") |
| def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): |
| """Applies Rotary Position Embedding to the query and key tensors. |
| |
| @@ -442,6 +443,7 @@ def __init__(self, config: Lfm2MoeConfig, layer_idx: int): |
| self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=False) |
| self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False) |
| self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False) |
| + self.rotary_fn = apply_rotary_pos_emb |
| self.out_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False) |
| self.q_layernorm = Lfm2MoeRMSNorm(self.head_dim, eps=config.norm_eps) |
| self.k_layernorm = Lfm2MoeRMSNorm(self.head_dim, eps=config.norm_eps) |
| |
| |
| |
| |
| @@ -199,6 +199,7 @@ def __init__(self, config: LightGlueConfig, layer_idx: int): |
| self.o_proj = nn.Linear( |
| config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias |
| ) |
| + self.rotary_fn = apply_rotary_pos_emb |
| |
| def forward( |
| self, |
| |
| |
| |
| |
| @@ -26,7 +26,7 @@ |
| from ...activations import ACT2FN |
| from ...cache_utils import Cache, DynamicCache |
| from ...generation import GenerationMixin |
| -from ...integrations import use_kernel_forward_from_hub |
| +from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub |
| from ...masking_utils import create_causal_mask |
| from ...modeling_layers import ( |
| GenericForQuestionAnswering, |
| @@ -142,6 +142,7 @@ def rotate_half(x): |
| return torch.cat((-x2, x1), dim=-1) |
| |
| |
| +@use_kernel_func_from_hub("rotary_pos_emb") |
| def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): |
| """Applies Rotary Position Embedding to the query and key tensors. |
| |
| @@ -248,6 +249,7 @@ def __init__(self, config: LlamaConfig, layer_idx: int): |
| self.o_proj = nn.Linear( |
| config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias |
| ) |
| + self.rotary_fn = apply_rotary_pos_emb |
| |
| def forward( |
| self, |
| |
| |
| |
| |
| @@ -31,7 +31,7 @@ |
| from ...activations import ACT2FN |
| from ...cache_utils import Cache, DynamicCache |
| from ...generation import GenerationMixin |
| -from ...integrations import use_kernel_forward_from_hub |
| +from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub |
| from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask |
| from ...modeling_flash_attention_utils import FlashAttentionKwargs |
| from ...modeling_layers import ( |
| @@ -326,6 +326,7 @@ def rotate_half(x): |
| return torch.cat((-x2, x1), dim=-1) |
| |
| |
| +@use_kernel_func_from_hub("rotary_pos_emb") |
| def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): |
| """Applies Rotary Position Embedding to the query and key tensors. |
| |
| @@ -407,6 +408,7 @@ def __init__(self, config: MiniMaxConfig, layer_idx: int): |
| self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False) |
| self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False) |
| self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False) |
| + self.rotary_fn = apply_rotary_pos_emb |
| |
| def forward( |
| self, |
| |
| |
| |
| |
| @@ -13,7 +13,7 @@ |
| from ...activations import ACT2FN |
| from ...cache_utils import Cache, DynamicCache |
| from ...generation import GenerationMixin |
| -from ...integrations import use_kernel_forward_from_hub |
| +from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub |
| from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask |
| from ...modeling_flash_attention_utils import FlashAttentionKwargs |
| from ...modeling_layers import ( |
| @@ -54,6 +54,7 @@ def rotate_half(x): |
| return torch.cat((-x2, x1), dim=-1) |
| |
| |
| +@use_kernel_func_from_hub("rotary_pos_emb") |
| def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): |
| """Applies Rotary Position Embedding to the query and key tensors. |
| |
| @@ -137,6 +138,7 @@ def __init__(self, config, layer_idx: int): |
| self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False) |
| self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False) |
| self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False) |
| + self.rotary_fn = apply_rotary_pos_emb |
| self.sliding_window = config.sliding_window if self.layer_type == "sliding_attention" else None |
| |
| def forward( |
| |
| |
| |
| |
| @@ -15,7 +15,7 @@ |
| from ...activations import ACT2FN |
| from ...cache_utils import Cache, DynamicCache |
| from ...generation import GenerationMixin |
| -from ...integrations import use_kernel_forward_from_hub |
| +from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub |
| from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask |
| from ...modeling_flash_attention_utils import FlashAttentionKwargs |
| from ...modeling_layers import ( |
| @@ -55,6 +55,7 @@ def rotate_half(x): |
| return torch.cat((-x2, x1), dim=-1) |
| |
| |
| +@use_kernel_func_from_hub("rotary_pos_emb") |
| def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): |
| """Applies Rotary Position Embedding to the query and key tensors. |
| |
| @@ -136,6 +137,7 @@ def __init__(self, config: MistralConfig, layer_idx: int): |
| self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False) |
| self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False) |
| self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False) |
| + self.rotary_fn = apply_rotary_pos_emb |
| |
| def forward( |
| self, |
| |
| |
| |
| |
| @@ -37,7 +37,7 @@ |
| from ...activations import ACT2FN |
| from ...cache_utils import Cache, DynamicCache |
| from ...generation import GenerationMixin |
| -from ...integrations import use_kernel_forward_from_hub |
| +from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub |
| from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask |
| from ...modeling_flash_attention_utils import FlashAttentionKwargs |
| from ...modeling_layers import ( |
| @@ -225,6 +225,7 @@ def rotate_half(x): |
| return torch.cat((-x2, x1), dim=-1) |
| |
| |
| +@use_kernel_func_from_hub("rotary_pos_emb") |
| def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): |
| """Applies Rotary Position Embedding to the query and key tensors. |
| |
| @@ -306,6 +307,7 @@ def __init__(self, config: MixtralConfig, layer_idx: int): |
| self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False) |
| self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False) |
| self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False) |
| + self.rotary_fn = apply_rotary_pos_emb |
| |
| def forward( |
| self, |
| |
| |
| |
| |
| @@ -31,6 +31,7 @@ |
| |
| from ... import initialization as init |
| from ...activations import ACT2FN |
| +from ...integrations import use_kernel_func_from_hub |
| from ...modeling_attn_mask_utils import _prepare_4d_attention_mask |
| from ...modeling_layers import GradientCheckpointingLayer |
| from ...modeling_outputs import ( |
| @@ -331,6 +332,7 @@ def rotate_half(x): |
| return torch.cat((-x2, x1), dim=-1) |
| |
| |
| +@use_kernel_func_from_hub("rotary_pos_emb") |
| def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): |
| """Applies Rotary Position Embedding to the query and key tensors. |
| |
| |
| |
| |
| |
| @@ -31,6 +31,7 @@ |
| from ...activations import ACT2FN |
| from ...cache_utils import Cache, DynamicCache |
| from ...generation import GenerationMixin |
| +from ...integrations import use_kernel_func_from_hub |
| from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask |
| from ...modeling_layers import GradientCheckpointingLayer |
| from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast |
| @@ -183,6 +184,7 @@ def rotate_half(x): |
| return torch.cat((-x2, x1), dim=-1) |
| |
| |
| +@use_kernel_func_from_hub("rotary_pos_emb") |
| def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): |
| """Applies Rotary Position Embedding to the query and key tensors. |
| |
| |
| |
| |
| |
| @@ -264,6 +264,7 @@ def __init__( |
| config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias |
| ) |
| self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False) |
| + self.rotary_fn = apply_rotary_pos_emb |
| |
| # Pad head dimension to the next specified multiple. |
| if self.config.pad_head_dim_to_multiple_of is not None: |
| |
| |
| |
| |
| @@ -237,6 +237,7 @@ def __init__(self, config: OlmoConfig, layer_idx: int): |
| self.o_proj = nn.Linear( |
| config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias |
| ) |
| + self.rotary_fn = apply_rotary_pos_emb |
| |
| def forward( |
| self, |
| |
| |
| |
| |
| @@ -230,6 +230,7 @@ def __init__(self, config: Olmo2Config, layer_idx: Optional[int] = None): |
| self.o_proj = nn.Linear( |
| config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias |
| ) |
| + self.rotary_fn = apply_rotary_pos_emb |
| self.q_norm = Olmo2RMSNorm(config.num_attention_heads * self.head_dim, config.rms_norm_eps) |
| self.k_norm = Olmo2RMSNorm(config.num_key_value_heads * self.head_dim, config.rms_norm_eps) |
| |
| |
| |
| |
| |
| @@ -161,6 +161,7 @@ def __init__(self, config: Olmo3Config, layer_idx: int): |
| self.o_proj = nn.Linear( |
| config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias |
| ) |
| + self.rotary_fn = apply_rotary_pos_emb |
| self.q_norm = Olmo3RMSNorm(config.num_attention_heads * self.head_dim, config.rms_norm_eps) |
| self.k_norm = Olmo3RMSNorm(config.num_key_value_heads * self.head_dim, config.rms_norm_eps) |
| assert config.layer_types is not None |
| |
| |
| |
| |
| @@ -27,7 +27,7 @@ |
| from ...activations import ACT2FN |
| from ...cache_utils import Cache, DynamicCache |
| from ...generation import GenerationMixin |
| -from ...integrations import use_kernel_forward_from_hub |
| +from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub |
| from ...masking_utils import create_causal_mask |
| from ...modeling_layers import GradientCheckpointingLayer |
| from ...modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast |
| @@ -148,6 +148,7 @@ def rotate_half(x): |
| return torch.cat((-x2, x1), dim=-1) |
| |
| |
| +@use_kernel_func_from_hub("rotary_pos_emb") |
| def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): |
| """Applies Rotary Position Embedding to the query and key tensors. |
| |
| @@ -238,6 +239,7 @@ def __init__(self, config: OlmoeConfig, layer_idx: Optional[int] = None): |
| self.o_proj = nn.Linear( |
| config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias |
| ) |
| + self.rotary_fn = apply_rotary_pos_emb |
| self.q_norm = OlmoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps) |
| self.k_norm = OlmoeRMSNorm( |
| (config.hidden_size // config.num_attention_heads) * config.num_key_value_heads, eps=config.rms_norm_eps |
| |
| |
| |
| |
| @@ -29,6 +29,7 @@ |
| |
| from ... import initialization as init |
| from ...activations import ACT2FN |
| +from ...integrations import use_kernel_func_from_hub |
| from ...modeling_layers import GradientCheckpointingLayer |
| from ...modeling_outputs import BaseModelOutput, CausalLMOutput |
| from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel |
| @@ -182,6 +183,41 @@ def forward(self, hidden_states, attention_mask=None): |
| return hidden_states.transpose(1, 2) |
| |
| |
| +def rotate_half(x): |
| + """Rotates half the hidden dims of the input.""" |
| + x1 = x[..., : x.shape[-1] // 2] |
| + x2 = x[..., x.shape[-1] // 2 :] |
| + return torch.cat((-x2, x1), dim=-1) |
| + |
| + |
| +@use_kernel_func_from_hub("rotary_pos_emb") |
| +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): |
| + """Applies Rotary Position Embedding to the query and key tensors. |
| + |
| + Args: |
| + q (`torch.Tensor`): The query tensor. |
| + k (`torch.Tensor`): The key tensor. |
| + cos (`torch.Tensor`): The cosine part of the rotary embedding. |
| + sin (`torch.Tensor`): The sine part of the rotary embedding. |
| + position_ids (`torch.Tensor`, *optional*): |
| + Deprecated and unused. |
| + unsqueeze_dim (`int`, *optional*, defaults to 1): |
| + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and |
| + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note |
| + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and |
| + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes |
| + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have |
| + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. |
| + Returns: |
| + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. |
| + """ |
| + cos = cos.unsqueeze(unsqueeze_dim) |
| + sin = sin.unsqueeze(unsqueeze_dim) |
| + q_embed = (q * cos) + (rotate_half(q) * sin) |
| + k_embed = (k * cos) + (rotate_half(k) * sin) |
| + return q_embed, k_embed |
| + |
| + |
| def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: |
| """ |
| This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, |
| @@ -245,6 +281,7 @@ def __init__(self, config: ParakeetEncoderConfig, layer_idx: int): |
| self.o_proj = nn.Linear( |
| config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias |
| ) |
| + self.rotary_fn = apply_rotary_pos_emb |
| # W_{k,R} projection |
| self.relative_k_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=False) |
| # global content bias |
| |
| |
| |
| |
| @@ -13,6 +13,7 @@ |
| from ...activations import ACT2FN |
| from ...cache_utils import Cache, DynamicCache |
| from ...generation import GenerationMixin |
| +from ...integrations import use_kernel_func_from_hub |
| from ...masking_utils import create_causal_mask |
| from ...modeling_layers import ( |
| GenericForSequenceClassification, |
| @@ -105,6 +106,7 @@ def rotate_half(x): |
| return torch.cat((-x2, x1), dim=-1) |
| |
| |
| +@use_kernel_func_from_hub("rotary_pos_emb") |
| def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): |
| """Applies Rotary Position Embedding to the query and key tensors. |
| |
| @@ -185,6 +187,7 @@ def __init__(self, config: PhiConfig, layer_idx: int): |
| self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=True) |
| self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=True) |
| self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=True) |
| + self.rotary_fn = apply_rotary_pos_emb |
| self.dense = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=True) |
| self.rotary_ndims = int(self.head_dim * config.partial_rotary_factor) |
| self.qk_layernorm = config.qk_layernorm |
| |
| |
| |
| |
| @@ -30,7 +30,7 @@ |
| from ...activations import ACT2FN |
| from ...cache_utils import Cache, DynamicCache |
| from ...generation import GenerationMixin |
| -from ...integrations import use_kernel_forward_from_hub |
| +from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub |
| from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask |
| from ...modeling_layers import GenericForSequenceClassification, GradientCheckpointingLayer |
| from ...modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast |
| @@ -128,6 +128,7 @@ def rotate_half(x): |
| return torch.cat((-x2, x1), dim=-1) |
| |
| |
| +@use_kernel_func_from_hub("rotary_pos_emb") |
| def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): |
| """Applies Rotary Position Embedding to the query and key tensors. |
| |
| @@ -218,6 +219,7 @@ def __init__(self, config: PhimoeConfig, layer_idx: int): |
| self.o_proj = nn.Linear( |
| config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias |
| ) |
| + self.rotary_fn = apply_rotary_pos_emb |
| |
| def forward( |
| self, |
| |
| |
| |
| |
| @@ -13,7 +13,7 @@ |
| from ...activations import ACT2FN |
| from ...cache_utils import Cache, DynamicCache |
| from ...generation import GenerationMixin |
| -from ...integrations import use_kernel_forward_from_hub |
| +from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub |
| from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask |
| from ...modeling_flash_attention_utils import FlashAttentionKwargs |
| from ...modeling_layers import ( |
| @@ -119,6 +119,7 @@ def rotate_half(x): |
| return torch.cat((-x2, x1), dim=-1) |
| |
| |
| +@use_kernel_func_from_hub("rotary_pos_emb") |
| def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): |
| """Applies Rotary Position Embedding to the query and key tensors. |
| |
| @@ -201,6 +202,7 @@ def __init__(self, config: Qwen2Config, layer_idx: int): |
| self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=True) |
| self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=True) |
| self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False) |
| + self.rotary_fn = apply_rotary_pos_emb |
| self.sliding_window = config.sliding_window if self.layer_type == "sliding_attention" else None |
| |
| def forward( |
| |
| |
| |
| |
| @@ -35,7 +35,7 @@ |
| from ...activations import ACT2FN |
| from ...cache_utils import Cache, DynamicCache |
| from ...generation import GenerationMixin |
| -from ...integrations import use_kernel_forward_from_hub |
| +from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub |
| from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask |
| from ...modeling_layers import ( |
| GenericForQuestionAnswering, |
| @@ -161,6 +161,7 @@ def rotate_half(x): |
| return torch.cat((-x2, x1), dim=-1) |
| |
| |
| +@use_kernel_func_from_hub("rotary_pos_emb") |
| def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): |
| """Applies Rotary Position Embedding to the query and key tensors. |
| |
| @@ -243,6 +244,7 @@ def __init__(self, config: Qwen2MoeConfig, layer_idx: int): |
| self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.qkv_bias) |
| self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.qkv_bias) |
| self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False) |
| + self.rotary_fn = apply_rotary_pos_emb |
| if self.config.layer_types[layer_idx] == "sliding_attention": |
| self.sliding_window = config.sliding_window |
| |
| |
| |
| |
| |
| @@ -28,7 +28,7 @@ |
| from ...activations import ACT2FN |
| from ...cache_utils import Cache, DynamicCache |
| from ...generation import GenerationMixin |
| -from ...integrations import use_kernel_forward_from_hub |
| +from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub |
| from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask |
| from ...modeling_flash_attention_utils import FlashAttentionKwargs |
| from ...modeling_layers import ( |
| @@ -155,6 +155,7 @@ def rotate_half(x): |
| return torch.cat((-x2, x1), dim=-1) |
| |
| |
| +@use_kernel_func_from_hub("rotary_pos_emb") |
| def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): |
| """Applies Rotary Position Embedding to the query and key tensors. |
| |
| @@ -246,6 +247,7 @@ def __init__(self, config: Qwen3Config, layer_idx: int): |
| self.o_proj = nn.Linear( |
| config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias |
| ) |
| + self.rotary_fn = apply_rotary_pos_emb |
| self.q_norm = Qwen3RMSNorm(self.head_dim, eps=config.rms_norm_eps) # unlike olmo, only on the head dim! |
| self.k_norm = Qwen3RMSNorm(self.head_dim, eps=config.rms_norm_eps) # thus post q_norm does not need reshape |
| self.sliding_window = config.sliding_window if self.layer_type == "sliding_attention" else None |
| |
| |
| |
| |
| @@ -30,7 +30,7 @@ |
| from ...activations import ACT2FN |
| from ...cache_utils import Cache, DynamicCache |
| from ...generation import GenerationMixin |
| -from ...integrations import use_kernel_forward_from_hub |
| +from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub |
| from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask |
| from ...modeling_flash_attention_utils import FlashAttentionKwargs |
| from ...modeling_layers import ( |
| @@ -55,6 +55,7 @@ def rotate_half(x): |
| return torch.cat((-x2, x1), dim=-1) |
| |
| |
| +@use_kernel_func_from_hub("rotary_pos_emb") |
| def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): |
| """Applies Rotary Position Embedding to the query and key tensors. |
| |
| @@ -145,6 +146,7 @@ def __init__(self, config: Qwen3MoeConfig, layer_idx: int): |
| self.o_proj = nn.Linear( |
| config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias |
| ) |
| + self.rotary_fn = apply_rotary_pos_emb |
| self.q_norm = Qwen3MoeRMSNorm(self.head_dim, eps=config.rms_norm_eps) # unlike olmo, only on the head dim! |
| self.k_norm = Qwen3MoeRMSNorm(self.head_dim, eps=config.rms_norm_eps) # thus post q_norm does not need reshape |
| self.sliding_window = getattr(config, "sliding_window", None) |
| |
| |
| |
| |
| @@ -371,6 +371,7 @@ def __init__(self, config: Qwen3NextConfig, layer_idx: int): |
| self.o_proj = nn.Linear( |
| config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias |
| ) |
| + self.rotary_fn = apply_rotary_pos_emb |
| self.q_norm = Qwen3NextRMSNorm(self.head_dim, eps=config.rms_norm_eps) # unlike olmo, only on the head dim! |
| self.k_norm = Qwen3NextRMSNorm( |
| self.head_dim, eps=config.rms_norm_eps |
| |
| |
| |
| |
| @@ -35,7 +35,7 @@ |
| from ...activations import ACT2FN |
| from ...cache_utils import Cache, DynamicCache |
| from ...generation import GenerationMixin |
| -from ...integrations import use_kernel_forward_from_hub |
| +from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub |
| from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask |
| from ...modeling_flash_attention_utils import FlashAttentionKwargs |
| from ...modeling_layers import GradientCheckpointingLayer |
| @@ -1415,6 +1415,7 @@ def extra_repr(self): |
| return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" |
| |
| |
| +@use_kernel_func_from_hub("rotary_pos_emb") |
| def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): |
| """Applies Rotary Position Embedding to the query and key tensors. |
| |
| @@ -1467,6 +1468,7 @@ def __init__(self, config, layer_idx): |
| self.o_proj = nn.Linear( |
| config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias |
| ) |
| + self.rotary_fn = apply_rotary_pos_emb |
| self.q_norm = Qwen3OmniMoeThinkerTextRMSNorm( |
| self.head_dim, eps=config.rms_norm_eps |
| ) # unlike olmo, only on the head dim! |
| @@ -2348,6 +2350,7 @@ def __init__(self, config: Qwen3OmniMoeConfig, layer_idx: int): |
| self.o_proj = nn.Linear( |
| config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias |
| ) |
| + self.rotary_fn = apply_rotary_pos_emb |
| self.q_norm = Qwen3OmniMoeRMSNorm(self.head_dim, eps=config.rms_norm_eps) # unlike olmo, only on the head dim! |
| self.k_norm = Qwen3OmniMoeRMSNorm( |
| self.head_dim, eps=config.rms_norm_eps |
| @@ -3377,6 +3380,7 @@ def __init__(self, config: Qwen3OmniMoeCode2WavConfig, layer_idx): |
| self.o_proj = nn.Linear( |
| config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias |
| ) |
| + self.rotary_fn = apply_rotary_pos_emb |
| self.q_norm = nn.Identity() |
| self.k_norm = nn.Identity() |
| self.sliding_window = config.sliding_window |
| |
| |
| |
| |
| @@ -30,7 +30,7 @@ |
| from ...activations import ACT2FN |
| from ...cache_utils import Cache, DynamicCache |
| from ...generation import GenerationMixin |
| -from ...integrations import use_kernel_forward_from_hub |
| +from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub |
| from ...masking_utils import create_causal_mask |
| from ...modeling_flash_attention_utils import FlashAttentionKwargs |
| from ...modeling_layers import GradientCheckpointingLayer |
| @@ -385,6 +385,7 @@ def extra_repr(self): |
| return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" |
| |
| |
| +@use_kernel_func_from_hub("rotary_pos_emb") |
| def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): |
| """Applies Rotary Position Embedding to the query and key tensors. |
| |
| @@ -438,6 +439,7 @@ def __init__(self, config: Qwen3VLTextConfig, layer_idx: int): |
| self.o_proj = nn.Linear( |
| config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias |
| ) |
| + self.rotary_fn = apply_rotary_pos_emb |
| self.q_norm = Qwen3VLTextRMSNorm(self.head_dim, eps=config.rms_norm_eps) # unlike olmo, only on the head dim! |
| self.k_norm = Qwen3VLTextRMSNorm( |
| self.head_dim, eps=config.rms_norm_eps |
| |
| |
| |
| |
| @@ -31,7 +31,7 @@ |
| from ...activations import ACT2FN |
| from ...cache_utils import Cache, DynamicCache |
| from ...generation import GenerationMixin |
| -from ...integrations import use_kernel_forward_from_hub |
| +from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub |
| from ...masking_utils import create_causal_mask |
| from ...modeling_flash_attention_utils import FlashAttentionKwargs |
| from ...modeling_layers import GradientCheckpointingLayer |
| @@ -198,6 +198,7 @@ def eager_attention_forward( |
| return attn_output, attn_weights |
| |
| |
| +@use_kernel_func_from_hub("rotary_pos_emb") |
| def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): |
| """Applies Rotary Position Embedding to the query and key tensors. |
| |
| @@ -251,6 +252,7 @@ def __init__(self, config: Qwen3VLMoeTextConfig, layer_idx: int): |
| self.o_proj = nn.Linear( |
| config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias |
| ) |
| + self.rotary_fn = apply_rotary_pos_emb |
| self.q_norm = Qwen3VLMoeTextRMSNorm( |
| self.head_dim, eps=config.rms_norm_eps |
| ) # unlike olmo, only on the head dim! |
| |
| |
| |
| |
| @@ -27,7 +27,7 @@ |
| from ...activations import ACT2FN |
| from ...cache_utils import Cache, DynamicCache |
| from ...generation import GenerationMixin |
| -from ...integrations import use_kernel_forward_from_hub |
| +from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub |
| from ...masking_utils import create_causal_mask |
| from ...modeling_layers import ( |
| GenericForQuestionAnswering, |
| @@ -90,6 +90,7 @@ def rotate_half(x): |
| return torch.cat((-x2, x1), dim=-1) |
| |
| |
| +@use_kernel_func_from_hub("rotary_pos_emb") |
| def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): |
| """Applies Rotary Position Embedding to the query and key tensors. |
| |
| |
| |
| |
| |
| @@ -28,7 +28,7 @@ |
| from ...activations import ACT2FN |
| from ...cache_utils import Cache, DynamicCache |
| from ...generation import GenerationMixin |
| -from ...integrations import use_kernel_forward_from_hub |
| +from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub |
| from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask |
| from ...modeling_flash_attention_utils import FlashAttentionKwargs |
| from ...modeling_layers import ( |
| @@ -118,6 +118,7 @@ def rotate_half(x): |
| return torch.cat((-x2, x1), dim=-1) |
| |
| |
| +@use_kernel_func_from_hub("rotary_pos_emb") |
| def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): |
| """Applies Rotary Position Embedding to the query and key tensors. |
| |
| @@ -208,6 +209,7 @@ def __init__(self, config: SmolLM3Config, layer_idx: int): |
| self.o_proj = nn.Linear( |
| config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias |
| ) |
| + self.rotary_fn = apply_rotary_pos_emb |
| |
| self.use_rope = config.no_rope_layers[layer_idx] |
| self.sliding_window = ( |
| |
| |
| |
| |
| @@ -35,6 +35,7 @@ |
| from ...activations import ACT2FN |
| from ...cache_utils import Cache, DynamicCache |
| from ...generation import GenerationMixin |
| +from ...integrations import use_kernel_func_from_hub |
| from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask |
| from ...modeling_flash_attention_utils import FlashAttentionKwargs |
| from ...modeling_layers import ( |
| @@ -74,6 +75,7 @@ def rotate_half(x): |
| return torch.cat((-x2, x1), dim=-1) |
| |
| |
| +@use_kernel_func_from_hub("rotary_pos_emb") |
| def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): |
| """Applies Rotary Position Embedding to the query and key tensors. |
| |
| @@ -155,6 +157,7 @@ def __init__(self, config: Starcoder2Config, layer_idx: Optional[int] = None): |
| self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.use_bias) |
| self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.use_bias) |
| self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.use_bias) |
| + self.rotary_fn = apply_rotary_pos_emb |
| self.residual_dropout = config.residual_dropout |
| |
| def forward( |
| |
| |
| |
| |
| @@ -29,6 +29,7 @@ |
| from ...activations import ACT2FN |
| from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache |
| from ...generation import GenerationMixin |
| +from ...integrations import use_kernel_func_from_hub |
| from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask |
| from ...modeling_flash_attention_utils import FlashAttentionKwargs |
| from ...modeling_layers import GradientCheckpointingLayer |
| @@ -162,6 +163,7 @@ def rotate_half(x): |
| return torch.cat((-x2, x1), dim=-1) |
| |
| |
| +@use_kernel_func_from_hub("rotary_pos_emb") |
| def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): |
| """Applies Rotary Position Embedding to the query and key tensors. |
| |
| @@ -263,6 +265,7 @@ def __init__(self, config: T5GemmaModuleConfig, layer_idx: int): |
| self.o_proj = nn.Linear( |
| config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias |
| ) |
| + self.rotary_fn = apply_rotary_pos_emb |
| self.attn_logit_softcapping = self.config.attn_logit_softcapping |
| self.sliding_window = config.sliding_window if self.layer_type == "sliding_attention" else None |
| |
| @@ -338,6 +341,7 @@ def __init__(self, config: T5GemmaModuleConfig, layer_idx: int): |
| self.o_proj = nn.Linear( |
| config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias |
| ) |
| + self.rotary_fn = apply_rotary_pos_emb |
| self.attn_logit_softcapping = self.config.attn_logit_softcapping |
| |
| if config.cross_attention_hidden_size is None: |
| |
| |
| |
| |
| @@ -29,6 +29,7 @@ |
| from ...activations import ACT2FN |
| from ...cache_utils import Cache, DynamicCache |
| from ...generation import GenerationMixin |
| +from ...integrations import use_kernel_func_from_hub |
| from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask |
| from ...modeling_flash_attention_utils import FlashAttentionKwargs |
| from ...modeling_layers import GradientCheckpointingLayer |
| @@ -87,6 +88,7 @@ def rotate_half(x): |
| return torch.cat((-x2, x1), dim=-1) |
| |
| |
| +@use_kernel_func_from_hub("rotary_pos_emb") |
| def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): |
| """Applies Rotary Position Embedding to the query and key tensors. |
| |
| @@ -187,6 +189,7 @@ def __init__(self, config: VaultGemmaConfig, layer_idx: int): |
| self.o_proj = nn.Linear( |
| config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias |
| ) |
| + self.rotary_fn = apply_rotary_pos_emb |
| self.attn_logit_softcapping = self.config.attn_logit_softcapping |
| self.sliding_window = config.sliding_window if self.layer_type == "sliding_attention" else None |
| |
| |
| |
| |
| |
| @@ -32,6 +32,7 @@ |
| from ...activations import ACT2FN |
| from ...cache_utils import Cache |
| from ...generation import GenerationMixin |
| +from ...integrations import use_kernel_func_from_hub |
| from ...modeling_attn_mask_utils import AttentionMaskConverter |
| from ...modeling_flash_attention_utils import FlashAttentionKwargs |
| from ...modeling_layers import GradientCheckpointingLayer |
| @@ -316,6 +317,7 @@ def rotate_half(x): |
| return torch.cat((-x2, x1), dim=-1) |
| |
| |
| +@use_kernel_func_from_hub("rotary_pos_emb") |
| def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): |
| """Applies Rotary Position Embedding to the query and key tensors. |
| |
|
|