from typing import Optional, Tuple, Union

import torch
from einops import rearrange, repeat

from fla.ops.rotary import apply_rotary


def rotate_half(x, interleaved=False):
    if not interleaved:
        x1, x2 = x.chunk(2, dim=-1)
        return torch.cat((-x2, x1), dim=-1)
    else:
        x1, x2 = x[..., ::2], x[..., 1::2]
        return rearrange(torch.stack((-x2, x1), dim=-1), "... d two -> ... (d two)", two=2)
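
# A toy illustration of the two pairing conventions (assumed values, comments
# only): for a feature vector [a, b, c, d],
#   non-interleaved (GPT-NeoX style): halves (a, b) / (c, d)  ->  [-c, -d, a, b]
#   interleaved (GPT-J style):        pairs (a, b), (c, d)    ->  [-b, a, -d, c]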


def apply_rotary_emb_torch(x, cos, sin, interleaved=False):
    """
    x: (batch_size, seqlen, nheads, headdim)
    cos, sin: (seqlen, rotary_dim / 2) or (batch_size, seqlen, rotary_dim / 2)
    """
    ro_dim = cos.shape[-1] * 2
    assert ro_dim <= x.shape[-1]
    cos = repeat(cos, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)")
    sin = repeat(sin, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)")
    return torch.cat(
        [
            x[..., :ro_dim] * cos + rotate_half(x[..., :ro_dim], interleaved) * sin,
            x[..., ro_dim:],
        ],
        dim=-1,
    )
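
# Per rotary pair, the reference above applies the 2-D rotation
#   u' = u * cos(t * inv_freq_i) - v * sin(t * inv_freq_i)
#   v' = u * sin(t * inv_freq_i) + v * cos(t * inv_freq_i)
# where (u, v) = (x[..., i], x[..., i + ro_dim // 2]) in the non-interleaved
# layout and (x[..., 2 * i], x[..., 2 * i + 1]) in the interleaved one.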


class ApplyRotaryEmb(torch.autograd.Function):

    @staticmethod
    def forward(
        ctx,
        x,
        cos,
        sin,
        interleaved=False,
        inplace=False,
        seqlen_offsets: Union[int, torch.Tensor] = 0,
        cu_seqlens: Optional[torch.Tensor] = None,
        max_seqlen: Optional[int] = None,
    ):
        out = apply_rotary(
            x,
            cos,
            sin,
            seqlen_offsets=seqlen_offsets,
            cu_seqlens=cu_seqlens,
            max_seqlen=max_seqlen,
            interleaved=interleaved,
            inplace=inplace,
        )
        if isinstance(seqlen_offsets, int):
            # An int can't go through save_for_backward, so stash it on ctx instead.
            ctx.save_for_backward(cos, sin, cu_seqlens)
            ctx.seqlen_offsets = seqlen_offsets
        else:
            ctx.save_for_backward(cos, sin, cu_seqlens, seqlen_offsets)
            ctx.seqlen_offsets = None
        ctx.interleaved = interleaved
        ctx.inplace = inplace
        ctx.max_seqlen = max_seqlen
        return out if not inplace else x

    @staticmethod
    def backward(ctx, do):
        seqlen_offsets = ctx.seqlen_offsets
        if seqlen_offsets is None:
            cos, sin, cu_seqlens, seqlen_offsets = ctx.saved_tensors
        else:
            cos, sin, cu_seqlens = ctx.saved_tensors
        if not ctx.interleaved and not ctx.inplace:
            # Cloning works around a Triton "[CUDA]: invalid device context" error
            # reported upstream; the root cause is unclear.
            do = do.clone()
        # The rotation is orthogonal, so the gradient w.r.t. x is the incoming
        # gradient rotated by the negated angles (conjugate=True).
        dx = apply_rotary(
            do,
            cos,
            sin,
            seqlen_offsets=seqlen_offsets,
            cu_seqlens=cu_seqlens,
            max_seqlen=ctx.max_seqlen,
            interleaved=ctx.interleaved,
            inplace=ctx.inplace,
            conjugate=True,
        )
        return dx, None, None, None, None, None, None, None


def apply_rotary_emb(
    x,
    cos,
    sin,
    interleaved=False,
    inplace=False,
    seqlen_offsets: Union[int, torch.Tensor] = 0,
    cu_seqlens: Optional[torch.Tensor] = None,
    max_seqlen: Optional[int] = None,
):
    """
    Apply rotary embedding to the first rotary_dim of x, where
    rotary_dim = 2 * cos.shape[-1] and rotary_dim must be <= headdim.
    Arguments:
        x: (batch_size, seqlen, nheads, headdim) if cu_seqlens is None
            else (total_seqlen, nheads, headdim)
        cos, sin: (seqlen_rotary, rotary_dim / 2)
        interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style)
            instead of 1st half and 2nd half (GPT-NeoX style).
        inplace: if True, apply rotary embedding in-place.
        seqlen_offsets: (batch_size,) or int. Each sequence in x is shifted by this
            amount. Most commonly used in inference when we have a KV cache.
        cu_seqlens: (batch + 1,) or None
        max_seqlen: int
    Return:
        out: (batch_size, seqlen, nheads, headdim) if cu_seqlens is None
            else (total_seqlen, nheads, headdim)
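    Example (an illustrative sketch; `apply_rotary` is a Triton kernel, so the
    tensors are assumed to live on a CUDA device):
        q = torch.randn(2, 128, 8, 64, device="cuda")  # rotate 32 of the 64 dims
        inv_freq = 1.0 / 10000.0 ** (torch.arange(0, 32, 2, device="cuda").float() / 32)
        freqs = torch.outer(torch.arange(128, device="cuda").float(), inv_freq)
        q_rot = apply_rotary_emb(q, freqs.cos(), freqs.sin())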
| """ |
| return ApplyRotaryEmb.apply( |
| x, cos, sin, interleaved, inplace, seqlen_offsets, cu_seqlens, max_seqlen |
| ) |


# Alias matching the name used in flash-attention.
apply_rotary_emb_func = apply_rotary_emb


class RotaryEmbedding(torch.nn.Module):
    """
    The rotary position embeddings from RoFormer_ (Su et al.).
    A crucial insight from the method is that the query and keys are
    transformed by rotation matrices which depend on the relative positions.

    Other implementations are available in the Rotary Transformer repo_ and in
    GPT-NeoX_, the latter of which inspired this one.

    .. _RoFormer: https://arxiv.org/abs/2104.09864
    .. _repo: https://github.com/ZhuiyiTechnology/roformer
    .. _GPT-NeoX: https://github.com/EleutherAI/gpt-neox

    If scale_base is not None, this implements XPos (Sun et al., https://arxiv.org/abs/2212.10554).
    A recommended value for scale_base is 512: https://github.com/HazyResearch/flash-attention/issues/96
    Reference: https://github.com/sunyt32/torchscale/blob/main/torchscale/component/xpos_relative_position.py
    """

    def __init__(
        self,
        dim: int,
        base=10000.0,
        interleaved=False,
        scale_base=None,
        pos_idx_in_fp32=True,
        device=None,
    ):
        """
        interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style)
            instead of 1st half and 2nd half (GPT-NeoX style).
        pos_idx_in_fp32: if True, the position indices [0.0, ..., seqlen - 1] are in
            fp32; otherwise they might be in lower precision.
            This option was added because previously (before 2023-07-02) the position
            indices were constructed in the dtype of self.inv_freq. In most cases this
            is fp32, but if the model is trained in pure bf16 (not mixed precision),
            self.inv_freq would be bf16, and the position indices would be too. Because
            bf16 can only represent multiples of 8 in [1024, 2048) (e.g. both 1993.0
            and 1995.0 round to 1992.0), the embeddings for some positions coincide.
            We keep this option for compatibility with models previously trained in
            pure bf16.
        """
        super().__init__()
        self.dim = dim
        self.base = float(base)
        self.pos_idx_in_fp32 = pos_idx_in_fp32
        # Generate and save the inverse frequency buffer (non-trainable).
        inv_freq = self._compute_inv_freq(device)
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.interleaved = interleaved
        self.scale_base = scale_base
        scale = (
            (torch.arange(0, dim, 2, device=device, dtype=torch.float32) + 0.4 * dim)
            / (1.4 * dim)
            if scale_base is not None
            else None
        )
        self.register_buffer("scale", scale, persistent=False)

        self._seq_len_cached = 0
        self._cos_cached = None
        self._sin_cached = None
        self._cos_k_cached = None
        self._sin_k_cached = None

    def _compute_inv_freq(self, device=None):
        return 1.0 / (
            self.base
            ** (torch.arange(0, self.dim, 2, device=device, dtype=torch.float32) / self.dim)
        )
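    # i.e. inv_freq[j] = base ** (-2j / dim) for j in [0, dim / 2): the geometric
    # frequency schedule theta_j from the RoFormer paper.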

    def _update_cos_sin_cache(self, seqlen, device=None, dtype=None):
        # Rebuild the tables if the sequence length has grown, if we're on a new
        # device or dtype (e.g. due to tracing), or if we're switching from
        # inference mode back to training.
        if (
            seqlen > self._seq_len_cached
            or self._cos_cached is None
            or self._cos_cached.device != device
            or self._cos_cached.dtype != dtype
            or (self.training and self._cos_cached.is_inference())
        ):
            self._seq_len_cached = seqlen
            # We want fp32 position indices rather than self.inv_freq.dtype: the
            # model could be loaded in bf16, and arange outputs can be large, so
            # bf16 indices would lose a lot of precision (see __init__).
            if self.pos_idx_in_fp32:
                t = torch.arange(seqlen, device=device, dtype=torch.float32)
                # inv_freq should be fp32 here too, since it is multiplied with t
                # and the products are large; recompute it if it wasn't loaded in fp32.
                if self.inv_freq.dtype != torch.float32:
                    inv_freq = self._compute_inv_freq(device=device)
                else:
                    inv_freq = self.inv_freq
            else:
                t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
                inv_freq = self.inv_freq
            freqs = torch.outer(t, inv_freq)
            if self.scale is None:
                self._cos_cached = torch.cos(freqs).to(dtype)
                self._sin_cached = torch.sin(freqs).to(dtype)
            else:
                power = (
                    torch.arange(seqlen, dtype=self.scale.dtype, device=self.scale.device)
                    - seqlen // 2
                ) / self.scale_base
                scale = self.scale.to(device=power.device) ** rearrange(power, "s -> s 1")
                # Do the multiplication by `scale` in fp32, then cast to the target dtype.
                self._cos_cached = (torch.cos(freqs) * scale).to(dtype)
                self._sin_cached = (torch.sin(freqs) * scale).to(dtype)
                self._cos_k_cached = (torch.cos(freqs) / scale).to(dtype)
                self._sin_k_cached = (torch.sin(freqs) / scale).to(dtype)

    def forward(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        seqlen_offset: Union[int, torch.Tensor] = 0,
        max_seqlen: Optional[int] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        q: (batch, seqlen, nheads, headdim)
        k: (batch, seqlen, nheads, headdim)
        seqlen_offset: (batch_size,) or int. Each sequence in q and k is shifted by
            this amount. Most commonly used in inference when we have a KV cache.
            If it's a tensor of shape (batch_size,), then to update the cos / sin
            cache, one should also pass in max_seqlen, which will update the cache
            up to that length.
        Returns the rotated (q, k).
        """
        seqlen = q.shape[1]
        if max_seqlen is not None:
            self._update_cos_sin_cache(max_seqlen, device=q.device, dtype=q.dtype)
        elif isinstance(seqlen_offset, int):
            self._update_cos_sin_cache(seqlen + seqlen_offset, device=q.device, dtype=q.dtype)
        if self.scale is None:
            q = apply_rotary_emb_func(
                q,
                self._cos_cached,
                self._sin_cached,
                interleaved=self.interleaved,
                seqlen_offsets=seqlen_offset,
            )
            k = apply_rotary_emb_func(
                k,
                self._cos_cached,
                self._sin_cached,
                interleaved=self.interleaved,
                seqlen_offsets=seqlen_offset,
            )
        else:
            q = apply_rotary_emb_func(
                q,
                self._cos_cached,
                self._sin_cached,
                interleaved=self.interleaved,
                seqlen_offsets=seqlen_offset,
            )
            k = apply_rotary_emb_func(
                k,
                self._cos_k_cached,
                self._sin_k_cached,
                interleaved=self.interleaved,
                seqlen_offsets=seqlen_offset,
            )
        return q, k
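

if __name__ == "__main__":
    # Minimal smoke test for the pure-PyTorch reference path (an illustrative
    # sketch, not part of the library API; shapes below are arbitrary).
    batch, seqlen, nheads, headdim, rotary_dim = 2, 16, 4, 64, 32
    x = torch.randn(batch, seqlen, nheads, headdim)
    inv_freq = 1.0 / 10000.0 ** (torch.arange(0, rotary_dim, 2).float() / rotary_dim)
    freqs = torch.outer(torch.arange(seqlen).float(), inv_freq)
    out = apply_rotary_emb_torch(x, freqs.cos(), freqs.sin())
    # Dimensions past rotary_dim must pass through untouched.
    assert torch.equal(out[..., rotary_dim:], x[..., rotary_dim:])
    # Position 0 has cos = 1 and sin = 0, so the rotation there is the identity.
    assert torch.allclose(out[:, 0], x[:, 0])
    print("apply_rotary_emb_torch smoke test passed")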