harness / diffs /41147.patch
ArthurZ's picture
ArthurZ HF Staff
Initial harness: 100 perf tasks + Gradio browser
dfefe0b verified
diff --git a/src/transformers/integrations/__init__.py b/src/transformers/integrations/__init__.py
index 15dd7518150c..05f410c4437f 100755
--- a/src/transformers/integrations/__init__.py
+++ b/src/transformers/integrations/__init__.py
@@ -72,6 +72,7 @@
"register_kernel_mapping",
"replace_kernel_forward_from_hub",
"use_kernel_forward_from_hub",
+ "use_kernel_func_from_hub",
],
"integration_utils": [
"INTEGRATION_TO_CALLBACK",
@@ -212,6 +213,7 @@
register_kernel_mapping,
replace_kernel_forward_from_hub,
use_kernel_forward_from_hub,
+ use_kernel_func_from_hub,
)
from .integration_utils import (
INTEGRATION_TO_CALLBACK,
diff --git a/src/transformers/integrations/hub_kernels.py b/src/transformers/integrations/hub_kernels.py
index 0ab866ecbd6d..bada38965aed 100644
--- a/src/transformers/integrations/hub_kernels.py
+++ b/src/transformers/integrations/hub_kernels.py
@@ -34,6 +34,23 @@
register_kernel_mapping,
replace_kernel_forward_from_hub,
)
+ from kernels import (
+ use_kernel_forward_from_hub as _kernels_use_kernel_forward_from_hub,
+ )
+
+ # Try to import FuncRepository, fallback if not available
+ try:
+ from kernels import FuncRepository
+ except ImportError:
+ FuncRepository = None
+
+ # Try to import use_kernel_func_from_hub, fallback if not available
+ try:
+ from kernels import use_kernel_func_from_hub as _kernels_use_kernel_func_from_hub
+
+ _has_use_kernel_func_from_hub = True
+ except ImportError:
+ _has_use_kernel_func_from_hub = False
_TRANSFORMERS_USE_HUB_KERNELS = os.environ.get("USE_HUB_KERNELS", "YES").upper()
_kernels_available = True
@@ -41,8 +58,6 @@
def use_kernel_forward_from_hub(layer_name: str):
if _kernels_enabled:
- from kernels import use_kernel_forward_from_hub as _kernels_use_kernel_forward_from_hub
-
return _kernels_use_kernel_forward_from_hub(layer_name)
else:
logger.warning_once(
@@ -50,6 +65,21 @@ def use_kernel_forward_from_hub(layer_name: str):
)
return lambda cls: cls
+ def use_kernel_func_from_hub(func_name: str):
+ if _kernels_enabled and _has_use_kernel_func_from_hub:
+ return _kernels_use_kernel_func_from_hub(func_name)
+ else:
+ if not _has_use_kernel_func_from_hub:
+ logger.warning_once(
+ "use_kernel_func_from_hub is not available in the installed kernels version. "
+ "Please upgrade kernels to use this feature."
+ )
+ else:
+ logger.warning_once(
+ f"kernels hub usage is disabled through the environment USE_HUB_KERNELS={_TRANSFORMERS_USE_HUB_KERNELS}"
+ )
+ return lambda func: func
+
_KERNEL_MAPPING: dict[str, dict[Device | str, LayerRepository]] = {
"MultiScaleDeformableAttention": {
"cuda": LayerRepository(
@@ -164,6 +194,16 @@ def use_kernel_forward_from_hub(layer_name: str):
},
}
+ # Add function kernel mappings if FuncRepository is available
+ if FuncRepository is not None:
+ _KERNEL_MAPPING["rotary_pos_emb"] = {
+ "xpu": {
+ Mode.INFERENCE: FuncRepository(
+ repo_id="kernels-community/rotary", func_name="apply_rotary_transformers"
+ )
+ }
+ }
+
def has_key(d, key):
return key in d or any(isinstance(v, dict) and has_key(v, key) for v in d.values())
@@ -189,6 +229,12 @@ def decorator(cls):
return decorator
+ def use_kernel_func_from_hub(*args, **kwargs):
+ def decorator(func):
+ return func
+
+ return decorator
+
class LayerRepository:
def __init__(self, *args, **kwargs):
raise RuntimeError("LayerRepository requires `kernels` to be installed. Run `pip install kernels`.")
@@ -201,6 +247,11 @@ def replace_kernel_forward_from_hub(*args, **kwargs):
def register_kernel_mapping(*args, **kwargs):
raise RuntimeError("register_kernel_mapping requires `kernels` to be installed. Run `pip install kernels`.")
+ def register_kernel_mapping_transformers(*args, **kwargs):
+ raise RuntimeError(
+ "register_kernel_mapping_transformers requires `kernels` to be installed. Run `pip install kernels`."
+ )
+
_HUB_KERNEL_MAPPING: dict[str, dict[str, str]] = {
"causal-conv1d": {"repo_id": "kernels-community/causal-conv1d"},
@@ -319,6 +370,7 @@ def lazy_load_kernel(kernel_name: str, mapping: dict[str, ModuleType | None] = _
__all__ = [
"LayerRepository",
"use_kernel_forward_from_hub",
+ "use_kernel_func_from_hub",
"register_kernel_mapping",
"register_kernel_mapping_transformers",
"replace_kernel_forward_from_hub",
diff --git a/src/transformers/models/apertus/modeling_apertus.py b/src/transformers/models/apertus/modeling_apertus.py
index 77a3d65478d6..1780b6383cc4 100644
--- a/src/transformers/models/apertus/modeling_apertus.py
+++ b/src/transformers/models/apertus/modeling_apertus.py
@@ -28,7 +28,7 @@
from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache
from ...generation import GenerationMixin
-from ...integrations import use_kernel_forward_from_hub
+from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub
from ...masking_utils import create_causal_mask
from ...modeling_layers import GenericForTokenClassification, GradientCheckpointingLayer
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
@@ -147,6 +147,7 @@ def rotate_half(x):
return torch.cat((-x2, x1), dim=-1)
+@use_kernel_func_from_hub("rotary_pos_emb")
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors.
@@ -237,6 +238,7 @@ def __init__(self, config: ApertusConfig, layer_idx: Optional[int] = None):
self.o_proj = nn.Linear(
config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
)
+ self.rotary_fn = apply_rotary_pos_emb
self.q_norm = ApertusRMSNorm(self.head_dim, config.rms_norm_eps)
self.k_norm = ApertusRMSNorm(self.head_dim, config.rms_norm_eps)
diff --git a/src/transformers/models/arcee/modeling_arcee.py b/src/transformers/models/arcee/modeling_arcee.py
index 779f4a63e378..d7faffd25c6e 100644
--- a/src/transformers/models/arcee/modeling_arcee.py
+++ b/src/transformers/models/arcee/modeling_arcee.py
@@ -30,7 +30,7 @@
from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache
from ...generation import GenerationMixin
-from ...integrations import use_kernel_forward_from_hub
+from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub
from ...masking_utils import create_causal_mask
from ...modeling_layers import (
GenericForQuestionAnswering,
@@ -154,6 +154,7 @@ def rotate_half(x):
return torch.cat((-x2, x1), dim=-1)
+@use_kernel_func_from_hub("rotary_pos_emb")
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors.
@@ -244,6 +245,7 @@ def __init__(self, config: ArceeConfig, layer_idx: int):
self.o_proj = nn.Linear(
config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
)
+ self.rotary_fn = apply_rotary_pos_emb
def forward(
self,
diff --git a/src/transformers/models/aria/modeling_aria.py b/src/transformers/models/aria/modeling_aria.py
index 96a6a82da91d..998f770a4562 100644
--- a/src/transformers/models/aria/modeling_aria.py
+++ b/src/transformers/models/aria/modeling_aria.py
@@ -29,7 +29,7 @@
from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache
from ...generation import GenerationMixin
-from ...integrations import use_kernel_forward_from_hub
+from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub
from ...masking_utils import create_causal_mask
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_layers import GradientCheckpointingLayer
@@ -378,6 +378,7 @@ def rotate_half(x):
return torch.cat((-x2, x1), dim=-1)
+@use_kernel_func_from_hub("rotary_pos_emb")
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors.
@@ -468,6 +469,7 @@ def __init__(self, config: AriaTextConfig, layer_idx: int):
self.o_proj = nn.Linear(
config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
)
+ self.rotary_fn = apply_rotary_pos_emb
def forward(
self,
diff --git a/src/transformers/models/bamba/modeling_bamba.py b/src/transformers/models/bamba/modeling_bamba.py
index 2428222e0dbe..7f6cceeb8372 100644
--- a/src/transformers/models/bamba/modeling_bamba.py
+++ b/src/transformers/models/bamba/modeling_bamba.py
@@ -370,6 +370,7 @@ def __init__(self, config: BambaConfig, layer_idx: int):
self.o_proj = nn.Linear(
config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
)
+ self.rotary_fn = apply_rotary_pos_emb
def forward(
self,
diff --git a/src/transformers/models/bitnet/modeling_bitnet.py b/src/transformers/models/bitnet/modeling_bitnet.py
index 73597dd98d82..25568ca92365 100644
--- a/src/transformers/models/bitnet/modeling_bitnet.py
+++ b/src/transformers/models/bitnet/modeling_bitnet.py
@@ -27,7 +27,7 @@
from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache
from ...generation import GenerationMixin
-from ...integrations import use_kernel_forward_from_hub
+from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub
from ...masking_utils import create_causal_mask
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_layers import GradientCheckpointingLayer
@@ -85,6 +85,7 @@ def rotate_half(x):
return torch.cat((-x2, x1), dim=-1)
+@use_kernel_func_from_hub("rotary_pos_emb")
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors.
@@ -175,6 +176,7 @@ def __init__(self, config: BitNetConfig, layer_idx: int):
self.o_proj = nn.Linear(
config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
)
+ self.rotary_fn = apply_rotary_pos_emb
self.attn_sub_norm = BitNetRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
def forward(
diff --git a/src/transformers/models/cohere/modeling_cohere.py b/src/transformers/models/cohere/modeling_cohere.py
index 9fc2593d3175..752b693fe7b8 100644
--- a/src/transformers/models/cohere/modeling_cohere.py
+++ b/src/transformers/models/cohere/modeling_cohere.py
@@ -247,6 +247,7 @@ def __init__(self, config: CohereConfig, layer_idx: Optional[int] = None):
self.o_proj = nn.Linear(
config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
)
+ self.rotary_fn = apply_rotary_pos_emb
self.use_qk_norm = config.use_qk_norm
if self.use_qk_norm:
# When sharding the model using Tensor Parallelism, need to be careful to use n_local_heads
diff --git a/src/transformers/models/csm/modeling_csm.py b/src/transformers/models/csm/modeling_csm.py
index 87da76281717..3c58c7f6d45c 100644
--- a/src/transformers/models/csm/modeling_csm.py
+++ b/src/transformers/models/csm/modeling_csm.py
@@ -32,7 +32,7 @@
from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache
from ...generation import GenerationMixin
-from ...integrations import use_kernel_forward_from_hub
+from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub
from ...masking_utils import create_causal_mask
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
@@ -206,6 +206,7 @@ def rotate_half(x):
return torch.cat((-x2, x1), dim=-1)
+@use_kernel_func_from_hub("rotary_pos_emb")
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors.
@@ -296,6 +297,7 @@ def __init__(self, config: CsmConfig, layer_idx: int):
self.o_proj = nn.Linear(
config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
)
+ self.rotary_fn = apply_rotary_pos_emb
def forward(
self,
diff --git a/src/transformers/models/cwm/modeling_cwm.py b/src/transformers/models/cwm/modeling_cwm.py
index df9760ed1ba7..53cc8c23a8c6 100644
--- a/src/transformers/models/cwm/modeling_cwm.py
+++ b/src/transformers/models/cwm/modeling_cwm.py
@@ -28,7 +28,7 @@
from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache
from ...generation import GenerationMixin
-from ...integrations import use_kernel_forward_from_hub
+from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub
from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_layers import GradientCheckpointingLayer
@@ -113,6 +113,7 @@ def rotate_half(x):
return torch.cat((-x2, x1), dim=-1)
+@use_kernel_func_from_hub("rotary_pos_emb")
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors.
@@ -195,6 +196,7 @@ def __init__(self, config: CwmConfig, layer_idx: int):
self.k_proj = torch.nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False)
self.v_proj = torch.nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False)
self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False)
+ self.rotary_fn = apply_rotary_pos_emb
self.sliding_window = config.sliding_window if self.layer_type == "sliding_attention" else None
def forward(
diff --git a/src/transformers/models/dbrx/modeling_dbrx.py b/src/transformers/models/dbrx/modeling_dbrx.py
index ddf5fce4dfce..92df09947bc4 100644
--- a/src/transformers/models/dbrx/modeling_dbrx.py
+++ b/src/transformers/models/dbrx/modeling_dbrx.py
@@ -29,6 +29,7 @@
from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache
from ...generation import GenerationMixin
+from ...integrations import use_kernel_func_from_hub
from ...masking_utils import create_causal_mask
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast
@@ -112,6 +113,7 @@ def rotate_half(x):
return torch.cat((-x2, x1), dim=-1)
+@use_kernel_func_from_hub("rotary_pos_emb")
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors.
diff --git a/src/transformers/models/deepseek_v3/modeling_deepseek_v3.py b/src/transformers/models/deepseek_v3/modeling_deepseek_v3.py
index cfd8d91dfb9a..4c56277c69dd 100644
--- a/src/transformers/models/deepseek_v3/modeling_deepseek_v3.py
+++ b/src/transformers/models/deepseek_v3/modeling_deepseek_v3.py
@@ -16,7 +16,7 @@
from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache
from ...generation import GenerationMixin
-from ...integrations import use_kernel_forward_from_hub
+from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub
from ...masking_utils import create_causal_mask
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_layers import (
@@ -253,6 +253,7 @@ def rotate_half(x):
return torch.cat((-x2, x1), dim=-1)
+@use_kernel_func_from_hub("rotary_pos_emb")
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors.
diff --git a/src/transformers/models/dia/modeling_dia.py b/src/transformers/models/dia/modeling_dia.py
index 3a0ddf6e3f90..55bbfd96411d 100644
--- a/src/transformers/models/dia/modeling_dia.py
+++ b/src/transformers/models/dia/modeling_dia.py
@@ -27,7 +27,7 @@
from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache
-from ...integrations import use_kernel_forward_from_hub
+from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub
from ...masking_utils import create_bidirectional_mask, create_causal_mask
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_layers import GradientCheckpointingLayer
@@ -200,6 +200,7 @@ def rotate_half(x):
return torch.cat((-x2, x1), dim=-1)
+@use_kernel_func_from_hub("rotary_pos_emb")
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors.
diff --git a/src/transformers/models/diffllama/modeling_diffllama.py b/src/transformers/models/diffllama/modeling_diffllama.py
index 99524915b9f6..b23ee2ed948d 100644
--- a/src/transformers/models/diffllama/modeling_diffllama.py
+++ b/src/transformers/models/diffllama/modeling_diffllama.py
@@ -32,7 +32,7 @@
from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache, StaticCache
from ...generation import GenerationMixin
-from ...integrations import use_kernel_forward_from_hub
+from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub
from ...masking_utils import create_causal_mask
from ...modeling_flash_attention_utils import _flash_attention_forward, flash_attn_supports_top_left_mask
from ...modeling_layers import (
@@ -141,6 +141,7 @@ def rotate_half(x):
return torch.cat((-x2, x1), dim=-1)
+@use_kernel_func_from_hub("rotary_pos_emb")
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors.
diff --git a/src/transformers/models/doge/modeling_doge.py b/src/transformers/models/doge/modeling_doge.py
index b9ebf9856264..a3a94f88df55 100644
--- a/src/transformers/models/doge/modeling_doge.py
+++ b/src/transformers/models/doge/modeling_doge.py
@@ -33,7 +33,7 @@
from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache
from ...generation import GenerationMixin
-from ...integrations import use_kernel_forward_from_hub
+from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub
from ...integrations.flex_attention import compile_friendly_flex_attention
from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
from ...modeling_layers import GenericForSequenceClassification, GradientCheckpointingLayer
@@ -143,6 +143,7 @@ def rotate_half(x):
return torch.cat((-x2, x1), dim=-1)
+@use_kernel_func_from_hub("rotary_pos_emb")
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors.
diff --git a/src/transformers/models/dots1/modeling_dots1.py b/src/transformers/models/dots1/modeling_dots1.py
index b8ae00b6ab60..9092a3533e43 100644
--- a/src/transformers/models/dots1/modeling_dots1.py
+++ b/src/transformers/models/dots1/modeling_dots1.py
@@ -29,7 +29,7 @@
from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache
from ...generation import GenerationMixin
-from ...integrations import use_kernel_forward_from_hub
+from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub
from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_layers import GradientCheckpointingLayer
@@ -135,6 +135,7 @@ def rotate_half(x):
return torch.cat((-x2, x1), dim=-1)
+@use_kernel_func_from_hub("rotary_pos_emb")
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors.
@@ -226,6 +227,7 @@ def __init__(self, config: Dots1Config, layer_idx: int):
self.o_proj = nn.Linear(
config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
)
+ self.rotary_fn = apply_rotary_pos_emb
self.q_norm = Dots1RMSNorm(self.head_dim, eps=config.rms_norm_eps) # unlike olmo, only on the head dim!
self.k_norm = Dots1RMSNorm(self.head_dim, eps=config.rms_norm_eps) # thus post q_norm does not need reshape
self.sliding_window = config.sliding_window if self.layer_type == "sliding_attention" else None
diff --git a/src/transformers/models/emu3/modeling_emu3.py b/src/transformers/models/emu3/modeling_emu3.py
index 65671913b27f..d3d272aa8fdd 100644
--- a/src/transformers/models/emu3/modeling_emu3.py
+++ b/src/transformers/models/emu3/modeling_emu3.py
@@ -33,7 +33,7 @@
from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache
from ...generation import GenerationMixin
-from ...integrations import use_kernel_forward_from_hub
+from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub
from ...masking_utils import create_causal_mask
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
@@ -52,6 +52,7 @@ def rotate_half(x):
return torch.cat((-x2, x1), dim=-1)
+@use_kernel_func_from_hub("rotary_pos_emb")
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors.
@@ -142,6 +143,7 @@ def __init__(self, config: Emu3Config, layer_idx: int):
self.o_proj = nn.Linear(
config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
)
+ self.rotary_fn = apply_rotary_pos_emb
def forward(
self,
diff --git a/src/transformers/models/ernie4_5/modeling_ernie4_5.py b/src/transformers/models/ernie4_5/modeling_ernie4_5.py
index b53ddf923e70..c18f19206b21 100644
--- a/src/transformers/models/ernie4_5/modeling_ernie4_5.py
+++ b/src/transformers/models/ernie4_5/modeling_ernie4_5.py
@@ -221,6 +221,7 @@ def __init__(self, config: Ernie4_5Config, layer_idx: int):
self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.use_bias)
self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.use_bias)
self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.use_bias)
+ self.rotary_fn = apply_rotary_pos_emb
def forward(
self,
diff --git a/src/transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py b/src/transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py
index ccd05fe26347..75c36992b22d 100644
--- a/src/transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py
+++ b/src/transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py
@@ -244,6 +244,7 @@ def __init__(self, config: Ernie4_5_MoeConfig, layer_idx: int):
self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.use_bias)
self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.use_bias)
self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.use_bias)
+ self.rotary_fn = apply_rotary_pos_emb
def forward(
self,
diff --git a/src/transformers/models/evolla/modeling_evolla.py b/src/transformers/models/evolla/modeling_evolla.py
index f4d0ce11255f..2260204eb150 100644
--- a/src/transformers/models/evolla/modeling_evolla.py
+++ b/src/transformers/models/evolla/modeling_evolla.py
@@ -31,7 +31,7 @@
from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache
from ...generation import GenerationMixin
-from ...integrations import use_kernel_forward_from_hub
+from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub
from ...masking_utils import create_bidirectional_mask, create_causal_mask
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import (
@@ -1051,6 +1051,7 @@ def rotate_half(x):
return torch.cat((-x2, x1), dim=-1)
+@use_kernel_func_from_hub("rotary_pos_emb")
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors.
@@ -1115,6 +1116,7 @@ def __init__(self, config: EvollaConfig, layer_idx: int):
self.o_proj = nn.Linear(
config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
)
+ self.rotary_fn = apply_rotary_pos_emb
def forward(
self,
diff --git a/src/transformers/models/exaone4/modeling_exaone4.py b/src/transformers/models/exaone4/modeling_exaone4.py
index cb70c9cff142..2287f9f8e88d 100644
--- a/src/transformers/models/exaone4/modeling_exaone4.py
+++ b/src/transformers/models/exaone4/modeling_exaone4.py
@@ -31,7 +31,7 @@
from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache
from ...generation import GenerationMixin
-from ...integrations import use_kernel_forward_from_hub
+from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub
from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
from ...modeling_layers import (
GenericForQuestionAnswering,
@@ -140,6 +140,7 @@ def rotate_half(x):
return torch.cat((-x2, x1), dim=-1)
+@use_kernel_func_from_hub("rotary_pos_emb")
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors.
diff --git a/src/transformers/models/falcon_h1/modeling_falcon_h1.py b/src/transformers/models/falcon_h1/modeling_falcon_h1.py
index a6fd7a5aba99..1cf19c500737 100644
--- a/src/transformers/models/falcon_h1/modeling_falcon_h1.py
+++ b/src/transformers/models/falcon_h1/modeling_falcon_h1.py
@@ -36,7 +36,7 @@
from ... import initialization as init
from ...cache_utils import Cache
from ...generation import GenerationMixin
-from ...integrations import use_kernel_forward_from_hub
+from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub
from ...modeling_attn_mask_utils import AttentionMaskConverter
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_layers import GradientCheckpointingLayer
@@ -295,6 +295,7 @@ def rotate_half(x):
return torch.cat((-x2, x1), dim=-1)
+@use_kernel_func_from_hub("rotary_pos_emb")
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors.
@@ -385,6 +386,7 @@ def __init__(self, config: FalconH1Config, layer_idx: int):
self.o_proj = nn.Linear(
config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
)
+ self.rotary_fn = apply_rotary_pos_emb
self.key_multiplier = config.key_multiplier
def forward(
diff --git a/src/transformers/models/flex_olmo/modeling_flex_olmo.py b/src/transformers/models/flex_olmo/modeling_flex_olmo.py
index b948b420ad63..993a3dae1652 100644
--- a/src/transformers/models/flex_olmo/modeling_flex_olmo.py
+++ b/src/transformers/models/flex_olmo/modeling_flex_olmo.py
@@ -241,6 +241,7 @@ def __init__(self, config: FlexOlmoConfig, layer_idx: Optional[int] = None):
self.o_proj = nn.Linear(
config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
)
+ self.rotary_fn = apply_rotary_pos_emb
self.q_norm = FlexOlmoRMSNorm(config.num_attention_heads * self.head_dim, config.rms_norm_eps)
self.k_norm = FlexOlmoRMSNorm(config.num_key_value_heads * self.head_dim, config.rms_norm_eps)
diff --git a/src/transformers/models/gemma/modeling_gemma.py b/src/transformers/models/gemma/modeling_gemma.py
index 8834ba8c3564..ae5e25202a1d 100644
--- a/src/transformers/models/gemma/modeling_gemma.py
+++ b/src/transformers/models/gemma/modeling_gemma.py
@@ -29,6 +29,7 @@
from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache
from ...generation import GenerationMixin
+from ...integrations import use_kernel_func_from_hub
from ...masking_utils import create_causal_mask
from ...modeling_layers import (
GenericForSequenceClassification,
@@ -152,6 +153,7 @@ def rotate_half(x):
return torch.cat((-x2, x1), dim=-1)
+@use_kernel_func_from_hub("rotary_pos_emb")
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors.
@@ -242,6 +244,7 @@ def __init__(self, config: GemmaConfig, layer_idx: int):
self.o_proj = nn.Linear(
config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
)
+ self.rotary_fn = apply_rotary_pos_emb
def forward(
self,
diff --git a/src/transformers/models/gemma2/modeling_gemma2.py b/src/transformers/models/gemma2/modeling_gemma2.py
index 69e486032107..d6362634b1e3 100644
--- a/src/transformers/models/gemma2/modeling_gemma2.py
+++ b/src/transformers/models/gemma2/modeling_gemma2.py
@@ -29,6 +29,7 @@
from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache
from ...generation import GenerationMixin
+from ...integrations import use_kernel_func_from_hub
from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_layers import (
@@ -156,6 +157,7 @@ def rotate_half(x):
return torch.cat((-x2, x1), dim=-1)
+@use_kernel_func_from_hub("rotary_pos_emb")
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors.
@@ -256,6 +258,7 @@ def __init__(self, config: Gemma2Config, layer_idx: int):
self.o_proj = nn.Linear(
config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
)
+ self.rotary_fn = apply_rotary_pos_emb
self.attn_logit_softcapping = self.config.attn_logit_softcapping
self.sliding_window = config.sliding_window if self.layer_type == "sliding_attention" else None
diff --git a/src/transformers/models/gemma3/modeling_gemma3.py b/src/transformers/models/gemma3/modeling_gemma3.py
index 8e93ef9231b5..253931d4eeb1 100644
--- a/src/transformers/models/gemma3/modeling_gemma3.py
+++ b/src/transformers/models/gemma3/modeling_gemma3.py
@@ -31,6 +31,7 @@
from ...cache_utils import Cache, DynamicCache
from ...configuration_utils import PreTrainedConfig
from ...generation import GenerationMixin
+from ...integrations import use_kernel_func_from_hub
from ...masking_utils import create_causal_mask, create_masks_for_generate, create_sliding_window_causal_mask
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_layers import GenericForSequenceClassification, GradientCheckpointingLayer
@@ -230,6 +231,7 @@ def rotate_half(x):
return torch.cat((-x2, x1), dim=-1)
+@use_kernel_func_from_hub("rotary_pos_emb")
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors.
@@ -330,6 +332,7 @@ def __init__(self, config: Gemma3TextConfig, layer_idx: int):
self.o_proj = nn.Linear(
config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
)
+ self.rotary_fn = apply_rotary_pos_emb
self.attn_logit_softcapping = self.config.attn_logit_softcapping
self.sliding_window = config.sliding_window if self.layer_type == "sliding_attention" else None
self.is_sliding = self.layer_type == "sliding_attention"
diff --git a/src/transformers/models/gemma3n/modeling_gemma3n.py b/src/transformers/models/gemma3n/modeling_gemma3n.py
index 3be6ad39ddca..3807311ff674 100644
--- a/src/transformers/models/gemma3n/modeling_gemma3n.py
+++ b/src/transformers/models/gemma3n/modeling_gemma3n.py
@@ -1254,6 +1254,7 @@ def __init__(self, config: Gemma3nTextConfig, layer_idx: int):
self.o_proj = nn.Linear(
config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
)
+ self.rotary_fn = apply_rotary_pos_emb
self.sliding_window = config.sliding_window if self.layer_type == "sliding_attention" else None
self.is_sliding = self.layer_type == "sliding_attention"
@@ -1475,6 +1476,7 @@ def __init__(self, config: Gemma3nConfig, layer_idx: int):
self.o_proj = nn.Linear(
config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
)
+ self.rotary_fn = apply_rotary_pos_emb
self.attn_logit_softcapping = self.config.attn_logit_softcapping
self.sliding_window = config.sliding_window if self.layer_type == "sliding_attention" else None
diff --git a/src/transformers/models/glm/modeling_glm.py b/src/transformers/models/glm/modeling_glm.py
index 8a508e2de54c..aefb879dc659 100644
--- a/src/transformers/models/glm/modeling_glm.py
+++ b/src/transformers/models/glm/modeling_glm.py
@@ -239,6 +239,7 @@ def __init__(self, config: GlmConfig, layer_idx: Optional[int] = None):
config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
)
self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False)
+ self.rotary_fn = apply_rotary_pos_emb
def forward(
self,
diff --git a/src/transformers/models/glm4/modeling_glm4.py b/src/transformers/models/glm4/modeling_glm4.py
index c982c36f9aab..3bfec8706d8f 100644
--- a/src/transformers/models/glm4/modeling_glm4.py
+++ b/src/transformers/models/glm4/modeling_glm4.py
@@ -221,6 +221,7 @@ def __init__(self, config: Glm4Config, layer_idx: Optional[int] = None):
config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
)
self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False)
+ self.rotary_fn = apply_rotary_pos_emb
def forward(
self,
diff --git a/src/transformers/models/glm4v_moe/modeling_glm4v_moe.py b/src/transformers/models/glm4v_moe/modeling_glm4v_moe.py
index 9537d9018838..74016aa98e7e 100644
--- a/src/transformers/models/glm4v_moe/modeling_glm4v_moe.py
+++ b/src/transformers/models/glm4v_moe/modeling_glm4v_moe.py
@@ -204,6 +204,48 @@ def rotate_half(x):
return torch.cat((-x2, x1), dim=-1)
+def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
+ """Applies Rotary Position Embedding to the query and key tensors.
+
+ Args:
+ q (`torch.Tensor`): The query tensor.
+ k (`torch.Tensor`): The key tensor.
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
+ position_ids (`torch.Tensor`, *optional*):
+ Deprecated and unused.
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
+ Returns:
+ `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
+ """
+ cos = cos.unsqueeze(unsqueeze_dim)
+ sin = sin.unsqueeze(unsqueeze_dim)
+
+ # Interleave them instead of usual shape
+ cos = cos[..., : cos.shape[-1] // 2].repeat_interleave(2, dim=-1)
+ sin = sin[..., : sin.shape[-1] // 2].repeat_interleave(2, dim=-1)
+
+ # Keep half or full tensor for later concatenation
+ rotary_dim = cos.shape[-1]
+ q_rot, q_pass = q[..., :rotary_dim], q[..., rotary_dim:]
+ k_rot, k_pass = k[..., :rotary_dim], k[..., rotary_dim:]
+
+ # Apply rotary embeddings on the first half or full tensor
+ q_embed = (q_rot * cos) + (rotate_half(q_rot) * sin)
+ k_embed = (k_rot * cos) + (rotate_half(k_rot) * sin)
+
+ # Concatenate back to full shape
+ q_embed = torch.cat([q_embed, q_pass], dim=-1)
+ k_embed = torch.cat([k_embed, k_pass], dim=-1)
+ return q_embed, k_embed
+
+
def apply_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section, unsqueeze_dim=1):
"""Applies Rotary Position Embedding with Multimodal Sections to the query and key tensors (https://qwenlm.github.io/blog/qwen2-vl/).
@@ -280,6 +322,7 @@ def __init__(self, config: Glm4vMoeTextConfig, layer_idx: Optional[int] = None):
config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
)
self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False)
+ self.rotary_fn = apply_rotary_pos_emb
self.rope_parameters = config.rope_parameters
def forward(
diff --git a/src/transformers/models/gpt_oss/modeling_gpt_oss.py b/src/transformers/models/gpt_oss/modeling_gpt_oss.py
index 8e1ce9df0b97..5e1173d823d0 100644
--- a/src/transformers/models/gpt_oss/modeling_gpt_oss.py
+++ b/src/transformers/models/gpt_oss/modeling_gpt_oss.py
@@ -325,6 +325,7 @@ def __init__(self, config: GptOssConfig, layer_idx: int):
self.o_proj = nn.Linear(
config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
)
+ self.rotary_fn = apply_rotary_pos_emb
self.sliding_window = config.sliding_window if self.layer_type == "sliding_attention" else None
self.sinks = nn.Parameter(torch.empty(config.num_attention_heads))
diff --git a/src/transformers/models/granite/modeling_granite.py b/src/transformers/models/granite/modeling_granite.py
index 42de2e0724f3..ae2d017d53bf 100644
--- a/src/transformers/models/granite/modeling_granite.py
+++ b/src/transformers/models/granite/modeling_granite.py
@@ -28,7 +28,7 @@
from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache
from ...generation import GenerationMixin
-from ...integrations import use_kernel_forward_from_hub
+from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub
from ...masking_utils import create_causal_mask
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
@@ -50,6 +50,7 @@ def rotate_half(x):
return torch.cat((-x2, x1), dim=-1)
+@use_kernel_func_from_hub("rotary_pos_emb")
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors.
@@ -140,6 +141,7 @@ def __init__(self, config: GraniteConfig, layer_idx: Optional[int] = None):
self.o_proj = nn.Linear(
config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
)
+ self.rotary_fn = apply_rotary_pos_emb
def forward(
self,
diff --git a/src/transformers/models/granitemoe/modeling_granitemoe.py b/src/transformers/models/granitemoe/modeling_granitemoe.py
index f722ad416a2f..d50147d5c686 100644
--- a/src/transformers/models/granitemoe/modeling_granitemoe.py
+++ b/src/transformers/models/granitemoe/modeling_granitemoe.py
@@ -30,7 +30,7 @@
from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache
from ...generation import GenerationMixin
-from ...integrations import use_kernel_forward_from_hub
+from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub
from ...masking_utils import create_causal_mask
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast
@@ -272,6 +272,7 @@ def rotate_half(x):
return torch.cat((-x2, x1), dim=-1)
+@use_kernel_func_from_hub("rotary_pos_emb")
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors.
@@ -362,6 +363,7 @@ def __init__(self, config: GraniteMoeConfig, layer_idx: int):
self.o_proj = nn.Linear(
config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
)
+ self.rotary_fn = apply_rotary_pos_emb
def forward(
self,
diff --git a/src/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py b/src/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py
index c9e7245956f3..9534b45c0624 100644
--- a/src/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py
+++ b/src/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py
@@ -31,7 +31,7 @@
from ... import initialization as init
from ...cache_utils import Cache
from ...generation import GenerationMixin
-from ...integrations import use_kernel_forward_from_hub
+from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub
from ...masking_utils import create_causal_mask
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import BaseModelOutputWithPast, MoeCausalLMOutputWithPast, MoeModelOutputWithPast
@@ -59,6 +59,41 @@
logger = logging.get_logger(__name__)
+def rotate_half(x):
+ """Rotates half the hidden dims of the input."""
+ x1 = x[..., : x.shape[-1] // 2]
+ x2 = x[..., x.shape[-1] // 2 :]
+ return torch.cat((-x2, x1), dim=-1)
+
+
+@use_kernel_func_from_hub("rotary_pos_emb")
+def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
+ """Applies Rotary Position Embedding to the query and key tensors.
+
+ Args:
+ q (`torch.Tensor`): The query tensor.
+ k (`torch.Tensor`): The key tensor.
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
+ position_ids (`torch.Tensor`, *optional*):
+ Deprecated and unused.
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
+ Returns:
+ `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
+ """
+ cos = cos.unsqueeze(unsqueeze_dim)
+ sin = sin.unsqueeze(unsqueeze_dim)
+ q_embed = (q * cos) + (rotate_half(q) * sin)
+ k_embed = (k * cos) + (rotate_half(k) * sin)
+ return q_embed, k_embed
+
+
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
"""
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
@@ -122,6 +157,7 @@ def __init__(self, config: GraniteMoeHybridConfig, layer_idx: int):
self.o_proj = nn.Linear(
config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
)
+ self.rotary_fn = apply_rotary_pos_emb
def forward(
self,
diff --git a/src/transformers/models/granitemoeshared/modeling_granitemoeshared.py b/src/transformers/models/granitemoeshared/modeling_granitemoeshared.py
index 606a59390e6e..78bcdd03c43e 100644
--- a/src/transformers/models/granitemoeshared/modeling_granitemoeshared.py
+++ b/src/transformers/models/granitemoeshared/modeling_granitemoeshared.py
@@ -30,7 +30,7 @@
from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache
from ...generation import GenerationMixin
-from ...integrations import use_kernel_forward_from_hub
+from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub
from ...masking_utils import create_causal_mask
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast
@@ -262,6 +262,7 @@ def rotate_half(x):
return torch.cat((-x2, x1), dim=-1)
+@use_kernel_func_from_hub("rotary_pos_emb")
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors.
@@ -352,6 +353,7 @@ def __init__(self, config: GraniteMoeSharedConfig, layer_idx: int):
self.o_proj = nn.Linear(
config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
)
+ self.rotary_fn = apply_rotary_pos_emb
def forward(
self,
diff --git a/src/transformers/models/helium/modeling_helium.py b/src/transformers/models/helium/modeling_helium.py
index e616da3cd07b..97a0f89b212b 100644
--- a/src/transformers/models/helium/modeling_helium.py
+++ b/src/transformers/models/helium/modeling_helium.py
@@ -243,6 +243,7 @@ def __init__(self, config: HeliumConfig, layer_idx: Optional[int] = None):
config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
)
self.o_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
+ self.rotary_fn = apply_rotary_pos_emb
def forward(
self,
diff --git a/src/transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py b/src/transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py
index 4d184a0b1982..28899e306f4c 100644
--- a/src/transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py
+++ b/src/transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py
@@ -30,7 +30,7 @@
from ...activations import ACT2FN
from ...cache_utils import DynamicCache
from ...generation import GenerationMixin
-from ...integrations import use_kernel_forward_from_hub
+from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub
from ...masking_utils import create_causal_mask
from ...modeling_layers import GenericForSequenceClassification, GradientCheckpointingLayer
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
@@ -87,6 +87,7 @@ def rotate_half(x):
return torch.cat((-x2, x1), dim=-1)
+@use_kernel_func_from_hub("rotary_pos_emb")
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors.
@@ -177,6 +178,7 @@ def __init__(self, config: HunYuanDenseV1Config, layer_idx: int):
self.o_proj = nn.Linear(
config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
)
+ self.rotary_fn = apply_rotary_pos_emb
self.query_layernorm = HunYuanDenseV1RMSNorm(self.head_dim, eps=config.rms_norm_eps)
self.key_layernorm = HunYuanDenseV1RMSNorm(self.head_dim, eps=config.rms_norm_eps)
diff --git a/src/transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py b/src/transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py
index 281a50a9e2cc..dda4366f0d4d 100644
--- a/src/transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py
+++ b/src/transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py
@@ -30,7 +30,7 @@
from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache
from ...generation import GenerationMixin
-from ...integrations import use_kernel_forward_from_hub
+from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub
from ...masking_utils import create_causal_mask
from ...modeling_layers import GenericForSequenceClassification, GradientCheckpointingLayer
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
@@ -86,6 +86,7 @@ def rotate_half(x):
return torch.cat((-x2, x1), dim=-1)
+@use_kernel_func_from_hub("rotary_pos_emb")
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors.
@@ -176,6 +177,7 @@ def __init__(self, config: HunYuanMoEV1Config, layer_idx: int):
self.o_proj = nn.Linear(
config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
)
+ self.rotary_fn = apply_rotary_pos_emb
self.query_layernorm = HunYuanMoEV1RMSNorm(self.head_dim, eps=config.rms_norm_eps)
self.key_layernorm = HunYuanMoEV1RMSNorm(self.head_dim, eps=config.rms_norm_eps)
diff --git a/src/transformers/models/jamba/modeling_jamba.py b/src/transformers/models/jamba/modeling_jamba.py
index fbda366fc319..47673a3afab2 100755
--- a/src/transformers/models/jamba/modeling_jamba.py
+++ b/src/transformers/models/jamba/modeling_jamba.py
@@ -32,7 +32,7 @@
from ... import initialization as init
from ...activations import ACT2FN
from ...generation import GenerationMixin
-from ...integrations import use_kernel_forward_from_hub
+from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub
from ...masking_utils import create_causal_mask
from ...modeling_layers import GenericForSequenceClassification, GradientCheckpointingLayer
from ...modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast
@@ -175,6 +175,41 @@ def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
return self.key_cache[layer_idx].shape[-2]
+def rotate_half(x):
+ """Rotates half the hidden dims of the input."""
+ x1 = x[..., : x.shape[-1] // 2]
+ x2 = x[..., x.shape[-1] // 2 :]
+ return torch.cat((-x2, x1), dim=-1)
+
+
+@use_kernel_func_from_hub("rotary_pos_emb")
+def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
+ """Applies Rotary Position Embedding to the query and key tensors.
+
+ Args:
+ q (`torch.Tensor`): The query tensor.
+ k (`torch.Tensor`): The key tensor.
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
+ position_ids (`torch.Tensor`, *optional*):
+ Deprecated and unused.
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
+ Returns:
+ `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
+ """
+ cos = cos.unsqueeze(unsqueeze_dim)
+ sin = sin.unsqueeze(unsqueeze_dim)
+ q_embed = (q * cos) + (rotate_half(q) * sin)
+ k_embed = (k * cos) + (rotate_half(k) * sin)
+ return q_embed, k_embed
+
+
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
"""
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
@@ -229,6 +264,7 @@ def __init__(self, config: JambaConfig, layer_idx: int):
self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False)
self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False)
self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False)
+ self.rotary_fn = apply_rotary_pos_emb
def forward(
self,
diff --git a/src/transformers/models/jetmoe/modeling_jetmoe.py b/src/transformers/models/jetmoe/modeling_jetmoe.py
index b102a111e10f..1ef99fc46d7f 100644
--- a/src/transformers/models/jetmoe/modeling_jetmoe.py
+++ b/src/transformers/models/jetmoe/modeling_jetmoe.py
@@ -30,7 +30,7 @@
from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache
from ...generation import GenerationMixin
-from ...integrations import use_kernel_forward_from_hub
+from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub
from ...masking_utils import create_causal_mask
from ...modeling_layers import GenericForSequenceClassification, GradientCheckpointingLayer
from ...modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast
@@ -366,6 +366,7 @@ def rotate_half(x):
return torch.cat((-x2, x1), dim=-1)
+@use_kernel_func_from_hub("rotary_pos_emb")
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors.
diff --git a/src/transformers/models/lfm2/modeling_lfm2.py b/src/transformers/models/lfm2/modeling_lfm2.py
index f1d639d16bbd..aa66a2e2d80c 100644
--- a/src/transformers/models/lfm2/modeling_lfm2.py
+++ b/src/transformers/models/lfm2/modeling_lfm2.py
@@ -26,7 +26,7 @@
from ...cache_utils import Cache
from ...generation import GenerationMixin
-from ...integrations import use_kernel_forward_from_hub
+from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub
from ...masking_utils import create_causal_mask
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
@@ -292,6 +292,7 @@ def rotate_half(x):
return torch.cat((-x2, x1), dim=-1)
+@use_kernel_func_from_hub("rotary_pos_emb")
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors.
@@ -371,6 +372,7 @@ def __init__(self, config: Lfm2Config, layer_idx: int):
self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=False)
self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False)
self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False)
+ self.rotary_fn = apply_rotary_pos_emb
self.out_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False)
self.q_layernorm = Lfm2RMSNorm(self.head_dim, eps=config.norm_eps)
self.k_layernorm = Lfm2RMSNorm(self.head_dim, eps=config.norm_eps)
diff --git a/src/transformers/models/lfm2_moe/modeling_lfm2_moe.py b/src/transformers/models/lfm2_moe/modeling_lfm2_moe.py
index 73b9c4a8fde0..55dcde8b07e9 100644
--- a/src/transformers/models/lfm2_moe/modeling_lfm2_moe.py
+++ b/src/transformers/models/lfm2_moe/modeling_lfm2_moe.py
@@ -28,7 +28,7 @@
from ...activations import ACT2FN
from ...cache_utils import Cache
from ...generation import GenerationMixin
-from ...integrations import use_kernel_forward_from_hub
+from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub
from ...masking_utils import create_causal_mask
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, MoeModelOutputWithPast
@@ -363,6 +363,7 @@ def rotate_half(x):
return torch.cat((-x2, x1), dim=-1)
+@use_kernel_func_from_hub("rotary_pos_emb")
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors.
@@ -442,6 +443,7 @@ def __init__(self, config: Lfm2MoeConfig, layer_idx: int):
self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=False)
self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False)
self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False)
+ self.rotary_fn = apply_rotary_pos_emb
self.out_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False)
self.q_layernorm = Lfm2MoeRMSNorm(self.head_dim, eps=config.norm_eps)
self.k_layernorm = Lfm2MoeRMSNorm(self.head_dim, eps=config.norm_eps)
diff --git a/src/transformers/models/lightglue/modeling_lightglue.py b/src/transformers/models/lightglue/modeling_lightglue.py
index 0d5044c5a40a..e684b11130d8 100644
--- a/src/transformers/models/lightglue/modeling_lightglue.py
+++ b/src/transformers/models/lightglue/modeling_lightglue.py
@@ -199,6 +199,7 @@ def __init__(self, config: LightGlueConfig, layer_idx: int):
self.o_proj = nn.Linear(
config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
)
+ self.rotary_fn = apply_rotary_pos_emb
def forward(
self,
diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py
index e3adac5d117d..eb938b6cbf7e 100644
--- a/src/transformers/models/llama/modeling_llama.py
+++ b/src/transformers/models/llama/modeling_llama.py
@@ -26,7 +26,7 @@
from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache
from ...generation import GenerationMixin
-from ...integrations import use_kernel_forward_from_hub
+from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub
from ...masking_utils import create_causal_mask
from ...modeling_layers import (
GenericForQuestionAnswering,
@@ -142,6 +142,7 @@ def rotate_half(x):
return torch.cat((-x2, x1), dim=-1)
+@use_kernel_func_from_hub("rotary_pos_emb")
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors.
@@ -248,6 +249,7 @@ def __init__(self, config: LlamaConfig, layer_idx: int):
self.o_proj = nn.Linear(
config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
)
+ self.rotary_fn = apply_rotary_pos_emb
def forward(
self,
diff --git a/src/transformers/models/minimax/modeling_minimax.py b/src/transformers/models/minimax/modeling_minimax.py
index 004ed68cef23..8b8f8c9adb3a 100644
--- a/src/transformers/models/minimax/modeling_minimax.py
+++ b/src/transformers/models/minimax/modeling_minimax.py
@@ -31,7 +31,7 @@
from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache
from ...generation import GenerationMixin
-from ...integrations import use_kernel_forward_from_hub
+from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub
from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_layers import (
@@ -326,6 +326,7 @@ def rotate_half(x):
return torch.cat((-x2, x1), dim=-1)
+@use_kernel_func_from_hub("rotary_pos_emb")
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors.
@@ -407,6 +408,7 @@ def __init__(self, config: MiniMaxConfig, layer_idx: int):
self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False)
self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False)
self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False)
+ self.rotary_fn = apply_rotary_pos_emb
def forward(
self,
diff --git a/src/transformers/models/ministral/modeling_ministral.py b/src/transformers/models/ministral/modeling_ministral.py
index b1c8555fd96b..ccf6c9abcda5 100644
--- a/src/transformers/models/ministral/modeling_ministral.py
+++ b/src/transformers/models/ministral/modeling_ministral.py
@@ -13,7 +13,7 @@
from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache
from ...generation import GenerationMixin
-from ...integrations import use_kernel_forward_from_hub
+from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub
from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_layers import (
@@ -54,6 +54,7 @@ def rotate_half(x):
return torch.cat((-x2, x1), dim=-1)
+@use_kernel_func_from_hub("rotary_pos_emb")
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors.
@@ -137,6 +138,7 @@ def __init__(self, config, layer_idx: int):
self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False)
self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False)
self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False)
+ self.rotary_fn = apply_rotary_pos_emb
self.sliding_window = config.sliding_window if self.layer_type == "sliding_attention" else None
def forward(
diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py
index 60c7e2d49eed..43dc5c18a90b 100644
--- a/src/transformers/models/mistral/modeling_mistral.py
+++ b/src/transformers/models/mistral/modeling_mistral.py
@@ -15,7 +15,7 @@
from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache
from ...generation import GenerationMixin
-from ...integrations import use_kernel_forward_from_hub
+from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub
from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_layers import (
@@ -55,6 +55,7 @@ def rotate_half(x):
return torch.cat((-x2, x1), dim=-1)
+@use_kernel_func_from_hub("rotary_pos_emb")
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors.
@@ -136,6 +137,7 @@ def __init__(self, config: MistralConfig, layer_idx: int):
self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False)
self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False)
self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False)
+ self.rotary_fn = apply_rotary_pos_emb
def forward(
self,
diff --git a/src/transformers/models/mixtral/modeling_mixtral.py b/src/transformers/models/mixtral/modeling_mixtral.py
index 1faff1f4dcea..95b236dadce6 100644
--- a/src/transformers/models/mixtral/modeling_mixtral.py
+++ b/src/transformers/models/mixtral/modeling_mixtral.py
@@ -37,7 +37,7 @@
from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache
from ...generation import GenerationMixin
-from ...integrations import use_kernel_forward_from_hub
+from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub
from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_layers import (
@@ -225,6 +225,7 @@ def rotate_half(x):
return torch.cat((-x2, x1), dim=-1)
+@use_kernel_func_from_hub("rotary_pos_emb")
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors.
@@ -306,6 +307,7 @@ def __init__(self, config: MixtralConfig, layer_idx: int):
self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False)
self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False)
self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False)
+ self.rotary_fn = apply_rotary_pos_emb
def forward(
self,
diff --git a/src/transformers/models/modernbert/modeling_modernbert.py b/src/transformers/models/modernbert/modeling_modernbert.py
index 8069f2bec2ff..4869da818e01 100644
--- a/src/transformers/models/modernbert/modeling_modernbert.py
+++ b/src/transformers/models/modernbert/modeling_modernbert.py
@@ -31,6 +31,7 @@
from ... import initialization as init
from ...activations import ACT2FN
+from ...integrations import use_kernel_func_from_hub
from ...modeling_attn_mask_utils import _prepare_4d_attention_mask
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import (
@@ -331,6 +332,7 @@ def rotate_half(x):
return torch.cat((-x2, x1), dim=-1)
+@use_kernel_func_from_hub("rotary_pos_emb")
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors.
diff --git a/src/transformers/models/modernbert_decoder/modeling_modernbert_decoder.py b/src/transformers/models/modernbert_decoder/modeling_modernbert_decoder.py
index 7564e375716b..5059e623c9dd 100644
--- a/src/transformers/models/modernbert_decoder/modeling_modernbert_decoder.py
+++ b/src/transformers/models/modernbert_decoder/modeling_modernbert_decoder.py
@@ -31,6 +31,7 @@
from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache
from ...generation import GenerationMixin
+from ...integrations import use_kernel_func_from_hub
from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
@@ -183,6 +184,7 @@ def rotate_half(x):
return torch.cat((-x2, x1), dim=-1)
+@use_kernel_func_from_hub("rotary_pos_emb")
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors.
diff --git a/src/transformers/models/moonshine/modeling_moonshine.py b/src/transformers/models/moonshine/modeling_moonshine.py
index 373e1db4a217..87893284ec4e 100644
--- a/src/transformers/models/moonshine/modeling_moonshine.py
+++ b/src/transformers/models/moonshine/modeling_moonshine.py
@@ -264,6 +264,7 @@ def __init__(
config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
)
self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False)
+ self.rotary_fn = apply_rotary_pos_emb
# Pad head dimension to the next specified multiple.
if self.config.pad_head_dim_to_multiple_of is not None:
diff --git a/src/transformers/models/olmo/modeling_olmo.py b/src/transformers/models/olmo/modeling_olmo.py
index 2ba7a25f71b5..1dcaa6247155 100644
--- a/src/transformers/models/olmo/modeling_olmo.py
+++ b/src/transformers/models/olmo/modeling_olmo.py
@@ -237,6 +237,7 @@ def __init__(self, config: OlmoConfig, layer_idx: int):
self.o_proj = nn.Linear(
config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
)
+ self.rotary_fn = apply_rotary_pos_emb
def forward(
self,
diff --git a/src/transformers/models/olmo2/modeling_olmo2.py b/src/transformers/models/olmo2/modeling_olmo2.py
index 44e3592157af..aa6296f344f1 100644
--- a/src/transformers/models/olmo2/modeling_olmo2.py
+++ b/src/transformers/models/olmo2/modeling_olmo2.py
@@ -230,6 +230,7 @@ def __init__(self, config: Olmo2Config, layer_idx: Optional[int] = None):
self.o_proj = nn.Linear(
config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
)
+ self.rotary_fn = apply_rotary_pos_emb
self.q_norm = Olmo2RMSNorm(config.num_attention_heads * self.head_dim, config.rms_norm_eps)
self.k_norm = Olmo2RMSNorm(config.num_key_value_heads * self.head_dim, config.rms_norm_eps)
diff --git a/src/transformers/models/olmo3/modeling_olmo3.py b/src/transformers/models/olmo3/modeling_olmo3.py
index d49570982f48..003e691e9c82 100644
--- a/src/transformers/models/olmo3/modeling_olmo3.py
+++ b/src/transformers/models/olmo3/modeling_olmo3.py
@@ -161,6 +161,7 @@ def __init__(self, config: Olmo3Config, layer_idx: int):
self.o_proj = nn.Linear(
config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
)
+ self.rotary_fn = apply_rotary_pos_emb
self.q_norm = Olmo3RMSNorm(config.num_attention_heads * self.head_dim, config.rms_norm_eps)
self.k_norm = Olmo3RMSNorm(config.num_key_value_heads * self.head_dim, config.rms_norm_eps)
assert config.layer_types is not None
diff --git a/src/transformers/models/olmoe/modeling_olmoe.py b/src/transformers/models/olmoe/modeling_olmoe.py
index f078518e0c1f..d04cd421d441 100644
--- a/src/transformers/models/olmoe/modeling_olmoe.py
+++ b/src/transformers/models/olmoe/modeling_olmoe.py
@@ -27,7 +27,7 @@
from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache
from ...generation import GenerationMixin
-from ...integrations import use_kernel_forward_from_hub
+from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub
from ...masking_utils import create_causal_mask
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast
@@ -148,6 +148,7 @@ def rotate_half(x):
return torch.cat((-x2, x1), dim=-1)
+@use_kernel_func_from_hub("rotary_pos_emb")
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors.
@@ -238,6 +239,7 @@ def __init__(self, config: OlmoeConfig, layer_idx: Optional[int] = None):
self.o_proj = nn.Linear(
config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
)
+ self.rotary_fn = apply_rotary_pos_emb
self.q_norm = OlmoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.k_norm = OlmoeRMSNorm(
(config.hidden_size // config.num_attention_heads) * config.num_key_value_heads, eps=config.rms_norm_eps
diff --git a/src/transformers/models/parakeet/modeling_parakeet.py b/src/transformers/models/parakeet/modeling_parakeet.py
index 079307547fe5..d11d56d3a11c 100644
--- a/src/transformers/models/parakeet/modeling_parakeet.py
+++ b/src/transformers/models/parakeet/modeling_parakeet.py
@@ -29,6 +29,7 @@
from ... import initialization as init
from ...activations import ACT2FN
+from ...integrations import use_kernel_func_from_hub
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import BaseModelOutput, CausalLMOutput
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
@@ -182,6 +183,41 @@ def forward(self, hidden_states, attention_mask=None):
return hidden_states.transpose(1, 2)
+def rotate_half(x):
+ """Rotates half the hidden dims of the input."""
+ x1 = x[..., : x.shape[-1] // 2]
+ x2 = x[..., x.shape[-1] // 2 :]
+ return torch.cat((-x2, x1), dim=-1)
+
+
+@use_kernel_func_from_hub("rotary_pos_emb")
+def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
+ """Applies Rotary Position Embedding to the query and key tensors.
+
+ Args:
+ q (`torch.Tensor`): The query tensor.
+ k (`torch.Tensor`): The key tensor.
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
+ position_ids (`torch.Tensor`, *optional*):
+ Deprecated and unused.
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
+ Returns:
+ `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
+ """
+ cos = cos.unsqueeze(unsqueeze_dim)
+ sin = sin.unsqueeze(unsqueeze_dim)
+ q_embed = (q * cos) + (rotate_half(q) * sin)
+ k_embed = (k * cos) + (rotate_half(k) * sin)
+ return q_embed, k_embed
+
+
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
"""
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
@@ -245,6 +281,7 @@ def __init__(self, config: ParakeetEncoderConfig, layer_idx: int):
self.o_proj = nn.Linear(
config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
)
+ self.rotary_fn = apply_rotary_pos_emb
# W_{k,R} projection
self.relative_k_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=False)
# global content bias
diff --git a/src/transformers/models/phi/modeling_phi.py b/src/transformers/models/phi/modeling_phi.py
index 4a1530b78564..eec8794a3d65 100644
--- a/src/transformers/models/phi/modeling_phi.py
+++ b/src/transformers/models/phi/modeling_phi.py
@@ -13,6 +13,7 @@
from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache
from ...generation import GenerationMixin
+from ...integrations import use_kernel_func_from_hub
from ...masking_utils import create_causal_mask
from ...modeling_layers import (
GenericForSequenceClassification,
@@ -105,6 +106,7 @@ def rotate_half(x):
return torch.cat((-x2, x1), dim=-1)
+@use_kernel_func_from_hub("rotary_pos_emb")
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors.
@@ -185,6 +187,7 @@ def __init__(self, config: PhiConfig, layer_idx: int):
self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=True)
self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=True)
self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=True)
+ self.rotary_fn = apply_rotary_pos_emb
self.dense = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=True)
self.rotary_ndims = int(self.head_dim * config.partial_rotary_factor)
self.qk_layernorm = config.qk_layernorm
diff --git a/src/transformers/models/phimoe/modeling_phimoe.py b/src/transformers/models/phimoe/modeling_phimoe.py
index 12e41214094d..fc7c733f3271 100644
--- a/src/transformers/models/phimoe/modeling_phimoe.py
+++ b/src/transformers/models/phimoe/modeling_phimoe.py
@@ -30,7 +30,7 @@
from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache
from ...generation import GenerationMixin
-from ...integrations import use_kernel_forward_from_hub
+from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub
from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
from ...modeling_layers import GenericForSequenceClassification, GradientCheckpointingLayer
from ...modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast
@@ -128,6 +128,7 @@ def rotate_half(x):
return torch.cat((-x2, x1), dim=-1)
+@use_kernel_func_from_hub("rotary_pos_emb")
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors.
@@ -218,6 +219,7 @@ def __init__(self, config: PhimoeConfig, layer_idx: int):
self.o_proj = nn.Linear(
config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
)
+ self.rotary_fn = apply_rotary_pos_emb
def forward(
self,
diff --git a/src/transformers/models/qwen2/modeling_qwen2.py b/src/transformers/models/qwen2/modeling_qwen2.py
index 1215f3677603..060683cfb972 100644
--- a/src/transformers/models/qwen2/modeling_qwen2.py
+++ b/src/transformers/models/qwen2/modeling_qwen2.py
@@ -13,7 +13,7 @@
from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache
from ...generation import GenerationMixin
-from ...integrations import use_kernel_forward_from_hub
+from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub
from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_layers import (
@@ -119,6 +119,7 @@ def rotate_half(x):
return torch.cat((-x2, x1), dim=-1)
+@use_kernel_func_from_hub("rotary_pos_emb")
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors.
@@ -201,6 +202,7 @@ def __init__(self, config: Qwen2Config, layer_idx: int):
self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=True)
self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=True)
self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False)
+ self.rotary_fn = apply_rotary_pos_emb
self.sliding_window = config.sliding_window if self.layer_type == "sliding_attention" else None
def forward(
diff --git a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py
index 8bda140d3cdb..389bf016243a 100644
--- a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py
+++ b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py
@@ -35,7 +35,7 @@
from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache
from ...generation import GenerationMixin
-from ...integrations import use_kernel_forward_from_hub
+from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub
from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
from ...modeling_layers import (
GenericForQuestionAnswering,
@@ -161,6 +161,7 @@ def rotate_half(x):
return torch.cat((-x2, x1), dim=-1)
+@use_kernel_func_from_hub("rotary_pos_emb")
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors.
@@ -243,6 +244,7 @@ def __init__(self, config: Qwen2MoeConfig, layer_idx: int):
self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.qkv_bias)
self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.qkv_bias)
self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False)
+ self.rotary_fn = apply_rotary_pos_emb
if self.config.layer_types[layer_idx] == "sliding_attention":
self.sliding_window = config.sliding_window
diff --git a/src/transformers/models/qwen3/modeling_qwen3.py b/src/transformers/models/qwen3/modeling_qwen3.py
index 5f0f8974eb0a..9b6c41662929 100644
--- a/src/transformers/models/qwen3/modeling_qwen3.py
+++ b/src/transformers/models/qwen3/modeling_qwen3.py
@@ -28,7 +28,7 @@
from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache
from ...generation import GenerationMixin
-from ...integrations import use_kernel_forward_from_hub
+from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub
from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_layers import (
@@ -155,6 +155,7 @@ def rotate_half(x):
return torch.cat((-x2, x1), dim=-1)
+@use_kernel_func_from_hub("rotary_pos_emb")
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors.
@@ -246,6 +247,7 @@ def __init__(self, config: Qwen3Config, layer_idx: int):
self.o_proj = nn.Linear(
config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
)
+ self.rotary_fn = apply_rotary_pos_emb
self.q_norm = Qwen3RMSNorm(self.head_dim, eps=config.rms_norm_eps) # unlike olmo, only on the head dim!
self.k_norm = Qwen3RMSNorm(self.head_dim, eps=config.rms_norm_eps) # thus post q_norm does not need reshape
self.sliding_window = config.sliding_window if self.layer_type == "sliding_attention" else None
diff --git a/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py b/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py
index 477694d5fb2b..caf7e26a39fe 100644
--- a/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py
+++ b/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py
@@ -30,7 +30,7 @@
from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache
from ...generation import GenerationMixin
-from ...integrations import use_kernel_forward_from_hub
+from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub
from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_layers import (
@@ -55,6 +55,7 @@ def rotate_half(x):
return torch.cat((-x2, x1), dim=-1)
+@use_kernel_func_from_hub("rotary_pos_emb")
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors.
@@ -145,6 +146,7 @@ def __init__(self, config: Qwen3MoeConfig, layer_idx: int):
self.o_proj = nn.Linear(
config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
)
+ self.rotary_fn = apply_rotary_pos_emb
self.q_norm = Qwen3MoeRMSNorm(self.head_dim, eps=config.rms_norm_eps) # unlike olmo, only on the head dim!
self.k_norm = Qwen3MoeRMSNorm(self.head_dim, eps=config.rms_norm_eps) # thus post q_norm does not need reshape
self.sliding_window = getattr(config, "sliding_window", None)
diff --git a/src/transformers/models/qwen3_next/modeling_qwen3_next.py b/src/transformers/models/qwen3_next/modeling_qwen3_next.py
index 362c8fab007f..27581014e9f7 100644
--- a/src/transformers/models/qwen3_next/modeling_qwen3_next.py
+++ b/src/transformers/models/qwen3_next/modeling_qwen3_next.py
@@ -371,6 +371,7 @@ def __init__(self, config: Qwen3NextConfig, layer_idx: int):
self.o_proj = nn.Linear(
config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
)
+ self.rotary_fn = apply_rotary_pos_emb
self.q_norm = Qwen3NextRMSNorm(self.head_dim, eps=config.rms_norm_eps) # unlike olmo, only on the head dim!
self.k_norm = Qwen3NextRMSNorm(
self.head_dim, eps=config.rms_norm_eps
diff --git a/src/transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py b/src/transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py
index 1be0487cea98..73709df5d0bc 100644
--- a/src/transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py
+++ b/src/transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py
@@ -35,7 +35,7 @@
from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache
from ...generation import GenerationMixin
-from ...integrations import use_kernel_forward_from_hub
+from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub
from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_layers import GradientCheckpointingLayer
@@ -1415,6 +1415,7 @@ def extra_repr(self):
return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
+@use_kernel_func_from_hub("rotary_pos_emb")
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors.
@@ -1467,6 +1468,7 @@ def __init__(self, config, layer_idx):
self.o_proj = nn.Linear(
config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
)
+ self.rotary_fn = apply_rotary_pos_emb
self.q_norm = Qwen3OmniMoeThinkerTextRMSNorm(
self.head_dim, eps=config.rms_norm_eps
) # unlike olmo, only on the head dim!
@@ -2348,6 +2350,7 @@ def __init__(self, config: Qwen3OmniMoeConfig, layer_idx: int):
self.o_proj = nn.Linear(
config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
)
+ self.rotary_fn = apply_rotary_pos_emb
self.q_norm = Qwen3OmniMoeRMSNorm(self.head_dim, eps=config.rms_norm_eps) # unlike olmo, only on the head dim!
self.k_norm = Qwen3OmniMoeRMSNorm(
self.head_dim, eps=config.rms_norm_eps
@@ -3377,6 +3380,7 @@ def __init__(self, config: Qwen3OmniMoeCode2WavConfig, layer_idx):
self.o_proj = nn.Linear(
config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
)
+ self.rotary_fn = apply_rotary_pos_emb
self.q_norm = nn.Identity()
self.k_norm = nn.Identity()
self.sliding_window = config.sliding_window
diff --git a/src/transformers/models/qwen3_vl/modeling_qwen3_vl.py b/src/transformers/models/qwen3_vl/modeling_qwen3_vl.py
index aab768b1cf1c..9922ccb09bb6 100644
--- a/src/transformers/models/qwen3_vl/modeling_qwen3_vl.py
+++ b/src/transformers/models/qwen3_vl/modeling_qwen3_vl.py
@@ -30,7 +30,7 @@
from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache
from ...generation import GenerationMixin
-from ...integrations import use_kernel_forward_from_hub
+from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub
from ...masking_utils import create_causal_mask
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_layers import GradientCheckpointingLayer
@@ -385,6 +385,7 @@ def extra_repr(self):
return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
+@use_kernel_func_from_hub("rotary_pos_emb")
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors.
@@ -438,6 +439,7 @@ def __init__(self, config: Qwen3VLTextConfig, layer_idx: int):
self.o_proj = nn.Linear(
config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
)
+ self.rotary_fn = apply_rotary_pos_emb
self.q_norm = Qwen3VLTextRMSNorm(self.head_dim, eps=config.rms_norm_eps) # unlike olmo, only on the head dim!
self.k_norm = Qwen3VLTextRMSNorm(
self.head_dim, eps=config.rms_norm_eps
diff --git a/src/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py b/src/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py
index eab677bce4fe..28e2c85f156c 100644
--- a/src/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py
+++ b/src/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py
@@ -31,7 +31,7 @@
from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache
from ...generation import GenerationMixin
-from ...integrations import use_kernel_forward_from_hub
+from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub
from ...masking_utils import create_causal_mask
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_layers import GradientCheckpointingLayer
@@ -198,6 +198,7 @@ def eager_attention_forward(
return attn_output, attn_weights
+@use_kernel_func_from_hub("rotary_pos_emb")
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors.
@@ -251,6 +252,7 @@ def __init__(self, config: Qwen3VLMoeTextConfig, layer_idx: int):
self.o_proj = nn.Linear(
config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
)
+ self.rotary_fn = apply_rotary_pos_emb
self.q_norm = Qwen3VLMoeTextRMSNorm(
self.head_dim, eps=config.rms_norm_eps
) # unlike olmo, only on the head dim!
diff --git a/src/transformers/models/seed_oss/modeling_seed_oss.py b/src/transformers/models/seed_oss/modeling_seed_oss.py
index 682193ca8d51..877d5eaa4fc0 100644
--- a/src/transformers/models/seed_oss/modeling_seed_oss.py
+++ b/src/transformers/models/seed_oss/modeling_seed_oss.py
@@ -27,7 +27,7 @@
from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache
from ...generation import GenerationMixin
-from ...integrations import use_kernel_forward_from_hub
+from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub
from ...masking_utils import create_causal_mask
from ...modeling_layers import (
GenericForQuestionAnswering,
@@ -90,6 +90,7 @@ def rotate_half(x):
return torch.cat((-x2, x1), dim=-1)
+@use_kernel_func_from_hub("rotary_pos_emb")
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors.
diff --git a/src/transformers/models/smollm3/modeling_smollm3.py b/src/transformers/models/smollm3/modeling_smollm3.py
index e23d4993e84c..c59906f63692 100644
--- a/src/transformers/models/smollm3/modeling_smollm3.py
+++ b/src/transformers/models/smollm3/modeling_smollm3.py
@@ -28,7 +28,7 @@
from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache
from ...generation import GenerationMixin
-from ...integrations import use_kernel_forward_from_hub
+from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub
from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_layers import (
@@ -118,6 +118,7 @@ def rotate_half(x):
return torch.cat((-x2, x1), dim=-1)
+@use_kernel_func_from_hub("rotary_pos_emb")
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors.
@@ -208,6 +209,7 @@ def __init__(self, config: SmolLM3Config, layer_idx: int):
self.o_proj = nn.Linear(
config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
)
+ self.rotary_fn = apply_rotary_pos_emb
self.use_rope = config.no_rope_layers[layer_idx]
self.sliding_window = (
diff --git a/src/transformers/models/starcoder2/modeling_starcoder2.py b/src/transformers/models/starcoder2/modeling_starcoder2.py
index 042033fe3565..c013e97f2169 100644
--- a/src/transformers/models/starcoder2/modeling_starcoder2.py
+++ b/src/transformers/models/starcoder2/modeling_starcoder2.py
@@ -35,6 +35,7 @@
from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache
from ...generation import GenerationMixin
+from ...integrations import use_kernel_func_from_hub
from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_layers import (
@@ -74,6 +75,7 @@ def rotate_half(x):
return torch.cat((-x2, x1), dim=-1)
+@use_kernel_func_from_hub("rotary_pos_emb")
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors.
@@ -155,6 +157,7 @@ def __init__(self, config: Starcoder2Config, layer_idx: Optional[int] = None):
self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.use_bias)
self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.use_bias)
self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.use_bias)
+ self.rotary_fn = apply_rotary_pos_emb
self.residual_dropout = config.residual_dropout
def forward(
diff --git a/src/transformers/models/t5gemma/modeling_t5gemma.py b/src/transformers/models/t5gemma/modeling_t5gemma.py
index ac9b64929280..5c0410655ee4 100644
--- a/src/transformers/models/t5gemma/modeling_t5gemma.py
+++ b/src/transformers/models/t5gemma/modeling_t5gemma.py
@@ -29,6 +29,7 @@
from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache
from ...generation import GenerationMixin
+from ...integrations import use_kernel_func_from_hub
from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_layers import GradientCheckpointingLayer
@@ -162,6 +163,7 @@ def rotate_half(x):
return torch.cat((-x2, x1), dim=-1)
+@use_kernel_func_from_hub("rotary_pos_emb")
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors.
@@ -263,6 +265,7 @@ def __init__(self, config: T5GemmaModuleConfig, layer_idx: int):
self.o_proj = nn.Linear(
config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
)
+ self.rotary_fn = apply_rotary_pos_emb
self.attn_logit_softcapping = self.config.attn_logit_softcapping
self.sliding_window = config.sliding_window if self.layer_type == "sliding_attention" else None
@@ -338,6 +341,7 @@ def __init__(self, config: T5GemmaModuleConfig, layer_idx: int):
self.o_proj = nn.Linear(
config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
)
+ self.rotary_fn = apply_rotary_pos_emb
self.attn_logit_softcapping = self.config.attn_logit_softcapping
if config.cross_attention_hidden_size is None:
diff --git a/src/transformers/models/vaultgemma/modeling_vaultgemma.py b/src/transformers/models/vaultgemma/modeling_vaultgemma.py
index ee36d7519a53..ad93f819fa0e 100644
--- a/src/transformers/models/vaultgemma/modeling_vaultgemma.py
+++ b/src/transformers/models/vaultgemma/modeling_vaultgemma.py
@@ -29,6 +29,7 @@
from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache
from ...generation import GenerationMixin
+from ...integrations import use_kernel_func_from_hub
from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_layers import GradientCheckpointingLayer
@@ -87,6 +88,7 @@ def rotate_half(x):
return torch.cat((-x2, x1), dim=-1)
+@use_kernel_func_from_hub("rotary_pos_emb")
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors.
@@ -187,6 +189,7 @@ def __init__(self, config: VaultGemmaConfig, layer_idx: int):
self.o_proj = nn.Linear(
config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
)
+ self.rotary_fn = apply_rotary_pos_emb
self.attn_logit_softcapping = self.config.attn_logit_softcapping
self.sliding_window = config.sliding_window if self.layer_type == "sliding_attention" else None
diff --git a/src/transformers/models/zamba2/modeling_zamba2.py b/src/transformers/models/zamba2/modeling_zamba2.py
index 5b5f532cebf7..9ade9d1dfc8d 100644
--- a/src/transformers/models/zamba2/modeling_zamba2.py
+++ b/src/transformers/models/zamba2/modeling_zamba2.py
@@ -32,6 +32,7 @@
from ...activations import ACT2FN
from ...cache_utils import Cache
from ...generation import GenerationMixin
+from ...integrations import use_kernel_func_from_hub
from ...modeling_attn_mask_utils import AttentionMaskConverter
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_layers import GradientCheckpointingLayer
@@ -316,6 +317,7 @@ def rotate_half(x):
return torch.cat((-x2, x1), dim=-1)
+@use_kernel_func_from_hub("rotary_pos_emb")
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors.