|
|
|
|
|
|
| import torch
|
|
|
|
|
def apply_rotary_emb_qk_real(
    xqk: torch.Tensor,
    freqs_cos: torch.Tensor,
    freqs_sin: torch.Tensor,
) -> torch.Tensor:
    """
    Apply rotary embeddings to input tensors using the given frequency tensor without complex numbers.

    Args:
        xqk (torch.Tensor): Query and/or Key tensors to apply rotary embeddings. Shape: (B, S, *, num_heads, D)
                            Can be either just query or just key, or both stacked along some batch or * dim.
        freqs_cos (torch.Tensor): Precomputed cosine frequency tensor.
        freqs_sin (torch.Tensor): Precomputed sine frequency tensor.

    Returns:
        torch.Tensor: The input tensor with rotary embeddings applied.
    """
    # Treat each consecutive pair (xqk[..., 2i], xqk[..., 2i+1]) as the real and
    # imaginary components of a complex number.
    real = xqk[..., ::2]
    imag = xqk[..., 1::2]

    # Complex multiplication (real + i*imag) * (cos + i*sin), expressed with
    # real-valued ops; cast back in case the freqs promote the dtype.
    rot_real = (real * freqs_cos - imag * freqs_sin).type_as(xqk)
    rot_imag = (real * freqs_sin + imag * freqs_cos).type_as(xqk)

    # Re-interleave so that out[..., 2i] = rot_real[..., i] and
    # out[..., 2i+1] = rot_imag[..., i].
    return torch.stack((rot_real, rot_imag), dim=-1).flatten(start_dim=-2)
|
|
|