| | |
| | |
| | |
| |
|
| | from torch import nn |
| | import math |
| | import torch |
| | from ..utils.compile import torch_compile_lazy |
| |
|
| |
|
@torch_compile_lazy
def apply_rope(
    q: torch.Tensor,
    k: torch.Tensor,
    offset: torch.Tensor,
    max_period: float = 10_000,
    time_before_heads: bool = False,
):
    """Apply rotary positional embeddings (RoPE) to queries and keys.

    Consecutive pairs of features along the last dimension are treated as the
    real and imaginary parts of a complex number, and each pair is rotated by
    an angle that grows with the absolute time step and shrinks with the
    index of the pair.

    Args:
        q (torch.Tensor): queries, shape `[B, T, H, D]` if `time_before_heads`
            is True, else `[B, H, T, D]`. `D` must be even and positive.
        k (torch.Tensor): keys, same shape as `q`.
        offset (torch.Tensor): position of the first time step, e.g. when
            streaming. Either a single element shared by the whole batch, or
            `B` elements giving one offset per batch item.
        max_period (float): maximum period for the cos and sin.
        time_before_heads (bool): if True, expected `[B, T, H, D]`,
            else `[B, H, T, D]`.

    Returns:
        A tuple `(q_rotated, k_rotated)` with the same shapes and dtypes as
        the inputs `q` and `k`.
    """
    if time_before_heads:
        B, T, H, D = q.shape
    else:
        B, H, T, D = q.shape
    assert k.shape == q.shape
    assert D > 0
    assert D % 2 == 0
    assert max_period > 0

    # One frequency per feature pair, geometrically spaced from 1 down to
    # (almost) 1 / max_period. Computed in float32 whatever the input dtype.
    ds = torch.arange(D // 2, device=q.device, dtype=torch.float32)
    freqs = torch.exp(ds * (-math.log(max_period) * 2 / D))

    # Absolute positions, shape [N, T] where N is the number of offsets
    # (1 for a shared offset, B for per-batch-item offsets).
    steps = torch.arange(T, device=q.device, dtype=torch.float32)
    ts = offset.float().reshape(-1, 1) + steps
    # Insert singleton axes so that `freqs * ts` broadcasts against the
    # `[..., D // 2]` views of q and k in either layout.
    if time_before_heads:
        ts = ts[:, :, None, None]  # [N, T, 1, 1]
    else:
        ts = ts[:, None, :, None]  # [N, 1, T, 1]

    # Split the last dimension into (real, imaginary) pairs.
    dims = q.shape[:-1]
    q = q.view(*dims, D // 2, 2)
    k = k.view(*dims, D // 2, 2)

    qr = q[..., 0].float()
    qi = q[..., 1].float()

    kr = k[..., 0].float()
    ki = k[..., 1].float()

    # Complex multiplication by exp(i * angle), written out on real tensors.
    angles = freqs * ts
    rotr = torch.cos(angles)
    roti = torch.sin(angles)
    qor = qr * rotr - qi * roti
    qoi = qr * roti + qi * rotr

    kor = kr * rotr - ki * roti
    koi = kr * roti + ki * rotr

    # Cast back to the input dtype and re-interleave the pairs.
    dtype = q.dtype
    qo = torch.stack([qor.to(dtype), qoi.to(dtype)], dim=-1)
    ko = torch.stack([kor.to(dtype), koi.to(dtype)], dim=-1)

    return qo.view(*dims, D), ko.view(*dims, D)
| |
|
| |
|
class RotaryEmbedding(nn.Module):
    """Module wrapper around rotary positional embedding (RoPE).

    See [Su et al 2022](https://arxiv.org/abs/2104.09864). The module only
    stores the configured maximum period; the rotation itself is delegated
    to `apply_rope`.

    Args:
        max_period (float): Maximum period of the rotation frequencies.
    """

    def __init__(self, max_period: float = 10000.0):
        super().__init__()
        # Plain Python attribute, forwarded to `apply_rope` on every call.
        self.max_period = max_period

    def forward(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        offset: torch.Tensor,
        time_before_heads: bool = False,
    ):
        """Rotate both the query and the key tensors.

        Args:
            q (torch.Tensor): queries.
            k (torch.Tensor): keys.
            offset (torch.Tensor): current offset, e.g. when streaming.
            time_before_heads (bool): layout flag passed through unchanged
                to `apply_rope`.

        Returns:
            Whatever `apply_rope` returns for these arguments and the
            `max_period` configured on this module.
        """
        period = self.max_period
        return apply_rope(q, k, offset, period, time_before_heads)
| |
|