Kernels
danieldk HF Staff commited on
Commit
e363fa0
·
verified ·
1 Parent(s): 3ca0e2f

Build uploaded using `kernels`.

Browse files
.gitattributes CHANGED
@@ -38,3 +38,4 @@ build/torch29-cu130-x86_64-windows/rotary/_rotary_a793e44.pyd filter=lfs diff=lf
38
  build/torch210-cu128-x86_64-windows/rotary/_rotary_66b961a.pyd filter=lfs diff=lfs merge=lfs -text
39
  build/torch29-xpu20252-x86_64-windows/rotary/_rotary_66b961a.pyd filter=lfs diff=lfs merge=lfs -text
40
  build/torch210-cu128-x86_64-windows/rotary/_rotary_9f63cc2.pyd filter=lfs diff=lfs merge=lfs -text
 
 
38
  build/torch210-cu128-x86_64-windows/rotary/_rotary_66b961a.pyd filter=lfs diff=lfs merge=lfs -text
39
  build/torch29-xpu20252-x86_64-windows/rotary/_rotary_66b961a.pyd filter=lfs diff=lfs merge=lfs -text
40
  build/torch210-cu128-x86_64-windows/rotary/_rotary_9f63cc2.pyd filter=lfs diff=lfs merge=lfs -text
41
+ build/torch210-xpu20253-x86_64-windows/rotary/_rotary_9f63cc2.pyd filter=lfs diff=lfs merge=lfs -text
build/torch210-xpu20253-x86_64-windows/metadata.json ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ {
2
+ "version": 1,
3
+ "python-depends": []
4
+ }
build/torch210-xpu20253-x86_64-windows/rotary/__init__.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional, Tuple
2
+ import torch
3
+
4
+ from ._ops import ops
5
+
6
+
7
+ def apply_rotary(
8
+ x1: torch.Tensor,
9
+ x2: torch.Tensor,
10
+ cos: torch.Tensor,
11
+ sin: torch.Tensor,
12
+ out1: torch.Tensor,
13
+ out2: torch.Tensor,
14
+ conj: bool,
15
+ ) -> None:
16
+ ops.apply_rotary(x1, x2, cos, sin, out1, out2, conj)
17
+
18
+
19
+ def apply_rotary_transformers(
20
+ q: torch.Tensor,
21
+ k: torch.Tensor,
22
+ cos: torch.Tensor,
23
+ sin: torch.Tensor,
24
+ unsqueeze_dim: int = 1,
25
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
26
+ """
27
+ Rotary kernel implementation wrapper
28
+ Adapts rotary kernel implementation to match transformers apply_rotary_pos_emb signature
29
+ """
30
+ cos = cos.unsqueeze(unsqueeze_dim)
31
+ sin = sin.unsqueeze(unsqueeze_dim)
32
+
33
+ q_rotated = q.clone()
34
+ k_rotated = k.clone()
35
+
36
+ # Get half dimension for rotation
37
+ half_dim = q.shape[-1] // 2
38
+ q1 = q_rotated[..., :half_dim]
39
+ q2 = q_rotated[..., half_dim:]
40
+ k1 = k_rotated[..., :half_dim]
41
+ k2 = k_rotated[..., half_dim:]
42
+ if cos.shape[-1] != half_dim:
43
+ # Trim cos/sin to match half_dim
44
+ cos = cos[..., :half_dim]
45
+ sin = sin[..., :half_dim]
46
+
47
+ apply_rotary(q1, q2, cos, sin, q1, q2, False)
48
+ apply_rotary(k1, k2, cos, sin, k1, k2, False)
49
+ return q_rotated, k_rotated
50
+
51
+
52
+ __all__ = ["apply_rotary", "apply_rotary_transformers"]
build/torch210-xpu20253-x86_64-windows/rotary/_ops.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from . import _rotary_9f63cc2
3
+ ops = torch.ops._rotary_9f63cc2
4
+
5
+ def add_op_namespace_prefix(op_name: str):
6
+ """
7
+ Prefix op by namespace.
8
+ """
9
+ return f"_rotary_9f63cc2::{op_name}"
build/torch210-xpu20253-x86_64-windows/rotary/_rotary_9f63cc2.pyd ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:658a0baaf353b654aa25e2d83b268cda2130354dcbd1adb9101c2997a39408b3
3
+ size 202240