| |
| |
|
|
| |
| |
|
|
| import torch |
| from einops import rearrange |
| from flash_attn.ops.triton.rotary import apply_rotary |
|
|
| from typing import Optional, Tuple, Union |
|
|
|
|
class ApplyRotaryEmbUnpad(torch.autograd.Function):
    """Apply rotary embeddings in place to the q and k parts of a packed, unpadded qkv tensor.

    ``qkv`` has shape ``(total_nnz, 3, nheads, headdim)``. Only ``qkv[:, 0]`` (q) and
    ``qkv[:, 1]`` (k) are rotated; ``qkv[:, 2]`` (v) is left untouched. The rotation is done by
    flash-attn's ``apply_rotary``. Backward applies the conjugate rotation to the incoming
    gradient.
    """

    @staticmethod
    def forward(
        ctx,
        qkv,
        cos,
        sin,
        interleaved=False,
        seqlen_offsets: Union[int, torch.Tensor] = 0,
        cu_seqlens: Optional[torch.Tensor] = None,
        max_seqlen: Optional[int] = None,
    ):
        """Rotate q and k of ``qkv`` in place and return ``qkv``.

        qkv: (total_nnz, 3, nheads, headdim); modified in place.
        cos, sin: rotary tables, forwarded unchanged to ``apply_rotary``.
        interleaved, cu_seqlens, max_seqlen: forwarded unchanged to ``apply_rotary``.
        seqlen_offsets: int or tensor position offset, forwarded to ``apply_rotary``.
            A tensor is saved for backward; an int is stored directly on ``ctx``.
        """
        total_nnz, three, _nheads, headdim = qkv.shape
        assert three == 3
        rotary_kwargs = dict(
            seqlen_offsets=seqlen_offsets,
            cu_seqlens=cu_seqlens,
            max_seqlen=max_seqlen,
            interleaved=interleaved,
            inplace=True,
        )
        if qkv.is_contiguous():
            # One kernel launch for q and k together. For a contiguous qkv, the (2, nheads)
            # dims of qkv[:, :2] can be merged by view() without a copy, so the in-place
            # writes land directly in qkv.
            qk = qkv[:, :2].view(total_nnz, -1, headdim)
            apply_rotary(qk, cos, sin, **rotary_kwargs)
        else:
            # The merge above is not possible for arbitrary strides: rotate q and k
            # separately through views of qkv.
            for part in (qkv[:, 0], qkv[:, 1]):
                apply_rotary(part, cos, sin, **rotary_kwargs)

        # Tensors must go through save_for_backward; a plain int can live on ctx.
        if isinstance(seqlen_offsets, int):
            ctx.save_for_backward(cos, sin, cu_seqlens)
            ctx.seqlen_offsets = seqlen_offsets
        else:
            ctx.save_for_backward(cos, sin, cu_seqlens, seqlen_offsets)
            ctx.seqlen_offsets = None
        ctx.interleaved = interleaved
        ctx.max_seqlen = max_seqlen
        return qkv

    @staticmethod
    def backward(ctx, do):
        """Rotate the q/k gradient by the conjugate rotation; pass the v gradient through.

        Returns one gradient per forward input; only ``qkv`` receives a gradient.
        """
        seqlen_offsets = ctx.seqlen_offsets
        if seqlen_offsets is None:
            cos, sin, cu_seqlens, seqlen_offsets = ctx.saved_tensors
        else:
            cos, sin, cu_seqlens = ctx.saved_tensors
        # The kernel below writes into the gradient in place. A non-contiguous incoming
        # gradient may have internal overlap (e.g. an expanded, stride-0 tensor). Writing in
        # place into such a tensor would race and corrupt the q/k *and* v gradients, so
        # materialise a dense copy first. contiguous() returns `do` itself when it is
        # already contiguous, so the common path stays copy-free.
        do = do.contiguous()
        total_nnz, _three, _nheads, headdim = do.shape
        # Same zero-copy (2, nheads) merge as in forward; valid because `do` is contiguous.
        dqk = do[:, :2].view(total_nnz, -1, headdim)
        apply_rotary(
            dqk,
            cos,
            sin,
            seqlen_offsets=seqlen_offsets,
            cu_seqlens=cu_seqlens,
            max_seqlen=ctx.max_seqlen,
            interleaved=ctx.interleaved,
            inplace=True,
            conjugate=True,
        )
        return do, None, None, None, None, None, None
|
|
|
|
def apply_rotary_emb_unpad(
    qkv,
    cos,
    sin,
    interleaved=False,
    seqlen_offsets: Union[int, torch.Tensor] = 0,
    cu_seqlens: Optional[torch.Tensor] = None,
    max_seqlen: Optional[int] = None,
):
    """
    Apply rotary embedding *in place* to the q and k parts of a packed, unpadded qkv tensor.

    Arguments:
        qkv: (total_nnz, 3, nheads, headdim) - input tensor for packed QKV. Slices
            qkv[:, 0] (q) and qkv[:, 1] (k) are rotated in place; qkv[:, 2] (v) is unchanged.
        cos, sin: (seqlen_rotary, rotary_dim / 2)
        interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead
            of 1st half and 2nd half (GPT-NeoX style).
        seqlen_offsets: (batch_size,) or int. Each sequence in qkv is shifted by this amount.
            Most commonly used in inference when we have KV cache.
        cu_seqlens: (batch + 1,) cumulative sequence lengths, or None.
        max_seqlen: int, the longest sequence length in the batch.
    Return:
        qkv: (total_nnz, 3, nheads, headdim), the input with q and k rotated in place.
    rotary_dim must be <= headdim; only the first rotary_dim channels of q and k are rotated.
    """
    return ApplyRotaryEmbUnpad.apply(qkv, cos, sin, interleaved, seqlen_offsets, cu_seqlens, max_seqlen)
|
|
|
|
class UnpaddedRotaryEmbedding(torch.nn.Module):
    """
    The rotary position embeddings applied directly to unpadded sequences.

    Keeps a lazily (re)built cos/sin cache and applies it in place to the q and k parts of a
    packed (total_nnz, 3, nheads, headdim) qkv tensor.
    """

    def __init__(
        self,
        dim: int,
        base: float = 10000.0,
        interleaved: bool = False,
        max_seqlen: Optional[int] = None,
        scale_base: Optional[float] = None,
        pos_idx_in_fp32: bool = True,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ):
        """
        dim: number of rotated channels; inv_freq holds one frequency per channel pair.
        base: base of the geometric frequency progression (inv_freq = 1 / base^(2i/dim)).
        interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead
            of 1st half and 2nd half (GPT-NeoX style).
        scale_base: if not None, the cached cos/sin are multiplied by
            scale ** ((position - seqlen // 2) / scale_base).
            NOTE(review): matching k-side caches (divided by the scale) are also built, but
            forward() applies the q-side cos/sin to both q and k.
        pos_idx_in_fp32: if True, the position indices [0.0, ..., seqlen - 1] are in fp32,
            otherwise they might be in lower precision.
            This option was added because previously (before 2023-07-02), when we construct
            the position indices, we use the dtype of self.inv_freq. In most cases this would
            be fp32, but if the model is trained in pure bf16 (not mixed precision), then
            self.inv_freq would be bf16, and the position indices are also in bf16.
            Because of the limited precision of bf16 (e.g. 1995.0 is rounded to 2000.0), the
            embeddings for some positions will coincide.
            To maintain compatibility with models previously trained in pure bf16,
            we add this option.
        max_seqlen: if max_seqlen, device, and dtype are provided, we precompute the cos_sin_cache
            up to max_seqlen. If the max_seqlen, device, or dtype during training/inference differ,
            the cos_sin_cache will be recomputed during the forward pass.
        """
        super().__init__()
        self.dim = dim
        self.base = float(base)
        self.pos_idx_in_fp32 = pos_idx_in_fp32
        # Non-persistent buffer: follows .to(device) but is not written to the state dict.
        inv_freq = self._compute_inv_freq(device)
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.interleaved = interleaved
        self.scale_base = scale_base
        scale = (
            (torch.arange(0, dim, 2, device=device, dtype=torch.float32) + 0.4 * dim) / (1.4 * dim)
            if scale_base is not None
            else None
        )
        self.register_buffer("scale", scale, persistent=False)

        self._seq_len_cached = 0
        self._cos_cached = None
        self._sin_cached = None
        self._cos_k_cached = None
        self._sin_k_cached = None

        if max_seqlen is not None and device is not None and dtype is not None:
            self._update_cos_sin_cache(max_seqlen, device=device, dtype=dtype)

    def _compute_inv_freq(self, device=None):
        """Return the fp32 inverse frequencies 1 / base^(arange(0, dim, 2) / dim)."""
        return 1.0 / (self.base ** (torch.arange(0, self.dim, 2, device=device, dtype=torch.float32) / self.dim))

    def _update_cos_sin_cache(self, seqlen, device=None, dtype=None):
        """(Re)build the cos/sin cache when it is too short or on the wrong device/dtype.

        The cache is also rebuilt in training mode if it was created under inference mode,
        because inference tensors cannot be saved for backward.
        """
        if (
            seqlen > self._seq_len_cached
            or self._cos_cached is None
            or self._cos_cached.device != device
            or self._cos_cached.dtype != dtype
            or (self.training and self._cos_cached.is_inference())
        ):
            self._seq_len_cached = seqlen
            if self.pos_idx_in_fp32:
                t = torch.arange(seqlen, device=device, dtype=torch.float32)
                # inv_freq may have been cast to a lower precision (e.g. model.bfloat16());
                # recompute it in fp32 so the outer product below stays in fp32.
                if self.inv_freq.dtype != torch.float32:
                    inv_freq = self._compute_inv_freq(device=device)
                else:
                    inv_freq = self.inv_freq
            else:
                t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
                inv_freq = self.inv_freq
            freqs = torch.outer(t, inv_freq)
            if self.scale is None:
                self._cos_cached = torch.cos(freqs).to(dtype)
                self._sin_cached = torch.sin(freqs).to(dtype)
            else:
                power = (
                    torch.arange(seqlen, dtype=self.scale.dtype, device=self.scale.device) - seqlen // 2
                ) / self.scale_base
                scale = self.scale.to(device=power.device) ** rearrange(power, "s -> s 1")
                self._cos_cached = (torch.cos(freqs) * scale).to(dtype)
                self._sin_cached = (torch.sin(freqs) * scale).to(dtype)
                self._cos_k_cached = (torch.cos(freqs) / scale).to(dtype)
                self._sin_k_cached = (torch.sin(freqs) / scale).to(dtype)

    @staticmethod
    def _infer_max_seqlen(cu_seqlens: torch.Tensor) -> int:
        """Return the longest sequence length described by ``cu_seqlens`` (0 if empty).

        Calls .item(), which forces a device-to-host sync when cu_seqlens is on a GPU.
        """
        if cu_seqlens.numel() < 2:
            return 0
        return int((cu_seqlens[1:] - cu_seqlens[:-1]).max().item())

    def forward(
        self,
        qkv: torch.Tensor,
        cu_seqlens: torch.Tensor,
        max_seqlen: Optional[int] = None,
        seqlen_offset: Union[int, torch.Tensor] = 0,
    ) -> torch.Tensor:
        """
        Apply rotary embedding *inplace* to qkv.

        qkv: (total_nnz, 3, nheads, headdim)
        cu_seqlens: (batch + 1,) cumulative sequence lengths
        max_seqlen: int max seq length in the batch. If None, it is derived from cu_seqlens
            (this costs a device-to-host sync).
        seqlen_offset: (batch_size,) or int. Each sequence in qkv is shifted by this amount.
            Most commonly used in inference when we have KV cache.
            With an int offset the cos / sin cache is grown to max_seqlen + seqlen_offset.
            If it's a tensor of shape (batch_size,), then to update the cos / sin cache, one
            should pass in max_seqlen, which will update the cos / sin cache up to that length.
        Returns qkv with q and k rotated in place.
        """
        if max_seqlen is None:
            # flash-attn's apply_rotary needs max_seqlen whenever cu_seqlens is given.
            max_seqlen = self._infer_max_seqlen(cu_seqlens)
        # With an int offset, positions reach max_seqlen + offset, so the cache must cover
        # them. For a tensor offset the caller must size the cache through max_seqlen.
        if isinstance(seqlen_offset, int):
            cache_len = max_seqlen + seqlen_offset
        else:
            cache_len = max_seqlen
        self._update_cos_sin_cache(cache_len, device=qkv.device, dtype=qkv.dtype)

        qkv = apply_rotary_emb_unpad(
            qkv,
            self._cos_cached,
            self._sin_cached,
            interleaved=self.interleaved,
            seqlen_offsets=seqlen_offset,
            cu_seqlens=cu_seqlens,
            max_seqlen=max_seqlen,
        )
        return qkv

    def extra_repr(self) -> str:
        return f"dim={self.dim}, base={self.base}, scale_base={self.scale_base}"
|
|