from typing import Optional, Tuple

import torch


def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # Split the last dimension into two halves and rotate: (x1, x2) -> (-x2, x1).
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)
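
# A small worked example of rotate_half on a head dimension of 4
# (symbolic values, for illustration only):
#   rotate_half([a, b, c, d]) == [-c, -d, a, b]
# i.e. the second half is negated and moved in front of the first half.
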

def apply_rotary_pos_emb(x, cos, sin, offset: int = 0):
    # `x` carries the sequence axis at dim 1; `cos` / `sin` are the precomputed
    # tables of shape [1, max_seq_len, 1, dim] built by RotaryEmbedding.
    assert cos.shape[1] >= offset + x.shape[1], (
        "Offset and/or input sequence is too large,"
        f"\n offset: {offset}, seq_len: {x.shape[1]}, max: {cos.shape[1]}"
    )

    # Slice the tables down to the positions covered by x, shifted by `offset`.
    cos_out = cos[:, offset : offset + x.shape[1], :, :]
    sin_out = sin[:, offset : offset + x.shape[1], :, :]

    return (x * cos_out) + (rotate_half(x) * sin_out)
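
# Per coordinate, apply_rotary_pos_emb computes (with theta_i = 1 / 10000 ** (2 * i / dim)
# and m = offset + position within the sequence):
#   out[..., i]          = x[..., i] * cos(m * theta_i) - x[..., i + dim//2] * sin(m * theta_i)
#   out[..., i + dim//2] = x[..., i + dim//2] * cos(m * theta_i) + x[..., i] * sin(m * theta_i)
# so dimension i is paired with dimension i + dim//2 and the pair is rotated by
# m * theta_i. This is the "half-split" pairing used by GPT-NeoX, rather than the
# interleaved (2i, 2i + 1) pairing from the original RoFormer paper.
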


class RotaryEmbedding(torch.nn.Module):
    """
    The rotary position embeddings from RoFormer_ (Su et al.).
    A crucial insight from the method is that the queries and keys are
    transformed by rotation matrices which depend on the relative positions.

    Other implementations are available in the Rotary Transformer repo_ and in
    GPT-NeoX_; GPT-NeoX was an inspiration for this implementation.

    .. _RoFormer: https://arxiv.org/abs/2104.09864
    .. _repo: https://github.com/ZhuiyiTechnology/roformer
    .. _GPT-NeoX: https://github.com/EleutherAI/gpt-neox

    .. warning::
        Please note that this embedding is not registered on purpose, as it is
        transformative (it does not create the embedding dimension) and will
        likely be picked up (imported) on an ad-hoc basis.
    """

    def __init__(self, dim_model: int, seq_len: int, *_, **__):
        super().__init__()

        self.dim_model = dim_model
        # Placeholder buffer; reset_parameters() fills in the actual frequencies.
        self.register_buffer("inv_freq", torch.zeros(self.dim_model // 2))

        self._cos_cached = None
        self._sin_cached = None
        self._seq_len_cached = 0
        self.seq_len = seq_len
        self.reset_parameters()

    def reset_parameters(self):
        # Standard RoPE frequencies: inv_freq[i] = 1 / 10000 ** (2 * i / dim_model),
        # followed by an initial fill of the cos / sin tables.
        self.inv_freq = 1.0 / (10000 ** (torch.arange(0, self.dim_model, 2).float() / self.dim_model))
        self._update_cos_sin_tables(self.seq_len)

    def _update_cos_sin_tables(
        self,
        seq_len: Optional[int] = None,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ):
        # Never shrink the cached tables: fall back to the cached length when the
        # requested one is smaller (or not given).
        if seq_len is None or seq_len < self._seq_len_cached:
            seq_len = self._seq_len_cached

        # Rebuild only if the sequence grew or the device / dtype changed.
        if seq_len > self._seq_len_cached or self._cos_cached.device != device or self._cos_cached.dtype != dtype:
            self._seq_len_cached = seq_len
            t = torch.arange(seq_len, device=device, dtype=torch.float32)
            freqs = torch.einsum("i,j->ij", t, self.inv_freq.to(dtype))
            emb = torch.cat((freqs, freqs), dim=-1).to(device)

            # Tables have shape [1, seq_len, 1, dim_model] so they broadcast over
            # batch and heads in apply_rotary_pos_emb.
            self._cos_cached = emb.cos()[None, :, None, :].to(dtype)
            self._sin_cached = emb.sin()[None, :, None, :].to(dtype)

    def forward(self, q: torch.Tensor, k: torch.Tensor, offset: int = 0) -> Tuple[torch.Tensor, torch.Tensor]:
        self._update_cos_sin_tables(k.shape[1] + offset, device=k.device, dtype=k.dtype)
        return (
            apply_rotary_pos_emb(q, self._cos_cached, self._sin_cached, offset),
            apply_rotary_pos_emb(k, self._cos_cached, self._sin_cached, offset),
        )


class RotaryWithCast(RotaryEmbedding):
    def forward(self, q, k, v, offset: int = 0):
        # Rotate q and k as usual, then cast them back to v's dtype so all three
        # tensors can be fed to attention together; v is passed through unchanged.
        q, k = super().forward(q, k, offset)
        return q.to(v.dtype), k.to(v.dtype), v
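

if __name__ == "__main__":
    # Minimal usage sketch. It assumes the (batch, seq_len, n_heads, head_dim)
    # layout implied by the table slicing above; all sizes are arbitrary
    # illustration values.
    batch, seq_len, n_heads, head_dim = 2, 16, 4, 32

    rope = RotaryEmbedding(dim_model=head_dim, seq_len=seq_len)
    q = torch.randn(batch, seq_len, n_heads, head_dim)
    k = torch.randn(batch, seq_len, n_heads, head_dim)

    q_rot, k_rot = rope(q, k)
    print(q_rot.shape, k_rot.shape)  # torch.Size([2, 16, 4, 32]) twice

    # When decoding a single new token at absolute position seq_len, pass it
    # with offset=seq_len so it receives the right rotation angles.
    q_new = torch.randn(batch, 1, n_heads, head_dim)
    k_new = torch.randn(batch, 1, n_heads, head_dim)
    q_off, k_off = rope(q_new, k_new, offset=seq_len)
    print(q_off.shape, k_off.shape)  # torch.Size([2, 1, 4, 32]) twice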