|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import torch |
|
|
import sys |
|
|
from kernels_benchmark_tools import KernelTypeEnum, run_benchmark |
|
|
|
|
|
|
|
|
def apply_rotary_torch(x1, x2, cos, sin, conj=False): |
|
|
"""Reference rotary implementation.""" |
|
|
if not conj: |
|
|
out1 = x1 * cos - x2 * sin |
|
|
out2 = x1 * sin + x2 * cos |
|
|
else: |
|
|
out1 = x1 * cos + x2 * sin |
|
|
out2 = -x1 * sin + x2 * cos |
|
|
return out1, out2 |
|
|
|
|
|
|
|
|
def torch_rotary(query, key, cos, sin, conj=False): |
|
|
rotary_dim = cos.shape[-1] |
|
|
|
|
|
|
|
|
q_out = query.clone() |
|
|
k_out = key.clone() |
|
|
|
|
|
|
|
|
q1 = q_out[..., :rotary_dim] |
|
|
q2 = q_out[..., rotary_dim : 2 * rotary_dim] |
|
|
q_out_1, q_out_2 = apply_rotary_torch(q1, q2, cos, sin, conj) |
|
|
q_out[..., :rotary_dim] = q_out_1 |
|
|
q_out[..., rotary_dim : 2 * rotary_dim] = q_out_2 |
|
|
|
|
|
|
|
|
k1 = k_out[..., :rotary_dim] |
|
|
k2 = k_out[..., rotary_dim : 2 * rotary_dim] |
|
|
k_out_1, k_out_2 = apply_rotary_torch(k1, k2, cos, sin, conj) |
|
|
k_out[..., :rotary_dim] = k_out_1 |
|
|
k_out[..., rotary_dim : 2 * rotary_dim] = k_out_2 |
|
|
|
|
|
return q_out, k_out |
|
|
|
|
|
|
|
|
run_benchmark( |
|
|
kernel_type=KernelTypeEnum.ROTARY, |
|
|
impl_name="torch_eager", |
|
|
impl_tags={"family": "pytorch", "backend": "eager"}, |
|
|
impl_func=torch_rotary, |
|
|
) |