diff --git a/.venv/lib/python3.11/site-packages/xformers/_flash_attn/layers/__init__.py b/.venv/lib/python3.11/site-packages/xformers/_flash_attn/layers/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/.venv/lib/python3.11/site-packages/xformers/_flash_attn/layers/patch_embed.py b/.venv/lib/python3.11/site-packages/xformers/_flash_attn/layers/patch_embed.py
new file mode 100644
index 0000000000000000000000000000000000000000..05562f8e8bcdb58e947c6f402a49eacd2d031871
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/xformers/_flash_attn/layers/patch_embed.py
@@ -0,0 +1,67 @@
+# We use the same API as https://github.com/rwightman/pytorch-image-models/blob/v0.6.11/timm/models/layers/patch_embed.py
+# But we use nn.Linear instead of Conv2d and it's about 8x faster.
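+# (Equivalence note: the rearrange + nn.Linear below computes the same projection,
+# up to weight layout, as Conv2d(in_chans, embed_dim, kernel_size=patch_size,
+# stride=patch_size) -- each non-overlapping patch is flattened and multiplied
+# by a single weight matrix.)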
+ +from functools import partial + +import torch.nn as nn +from einops import rearrange +from torch import _assert +from torch.nn.modules.utils import _pair + +try: + from flash_attn.ops.fused_dense import FusedDense +except ImportError: + FusedDense = None + + +class PatchEmbed(nn.Module): + """2D Image to Patch Embedding""" + + def __init__( + self, + img_size=224, + patch_size=16, + in_chans=3, + embed_dim=768, + norm_layer=None, + flatten=True, + bias=True, + fused_bias_fc=False, + ): + super().__init__() + img_size = _pair(img_size) + patch_size = _pair(patch_size) + self.img_size = img_size + self.patch_size = patch_size + self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1]) + self.num_patches = self.grid_size[0] * self.grid_size[1] + self.flatten = flatten + if fused_bias_fc and FusedDense is None: + raise ImportError("fused_dense is not installed") + + linear_cls = nn.Linear if not fused_bias_fc or not bias else FusedDense + self.proj = linear_cls(in_chans * patch_size[0] * patch_size[1], embed_dim, bias=bias) + self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() + + def forward(self, x): + _, _, H, W = x.shape + _assert( + H == self.img_size[0], + f"Input image height ({H}) doesn't match model ({self.img_size[0]}).", + ) + _assert( + W == self.img_size[1], + f"Input image width ({W}) doesn't match model ({self.img_size[1]}).", + ) + x = self.proj( + rearrange( + x, + "b c (h p1) (w p2) -> b h w (c p1 p2)", + p1=self.patch_size[0], + p2=self.patch_size[1], + ) + ) + if self.flatten: + x = rearrange(x, "b h w c -> b (h w) c") + x = self.norm(x) + return x diff --git a/.venv/lib/python3.11/site-packages/xformers/_flash_attn/layers/rotary.py b/.venv/lib/python3.11/site-packages/xformers/_flash_attn/layers/rotary.py new file mode 100644 index 0000000000000000000000000000000000000000..215f518ea7ca5e0ebe6bcd94657cc84eff6e975c --- /dev/null +++ b/.venv/lib/python3.11/site-packages/xformers/_flash_attn/layers/rotary.py @@ -0,0 +1,481 @@ +# Copyright (c) 2023, Tri Dao. + +import math +from typing import Optional, Tuple, Union + +import torch +from einops import rearrange, repeat +from flash_attn.ops.triton.rotary import apply_rotary + + +def rotate_half(x, interleaved=False): + if not interleaved: + x1, x2 = x.chunk(2, dim=-1) + return torch.cat((-x2, x1), dim=-1) + else: + x1, x2 = x[..., ::2], x[..., 1::2] + return rearrange(torch.stack((-x2, x1), dim=-1), "... d two -> ... (d two)", two=2) + + +def apply_rotary_emb_torch(x, cos, sin, interleaved=False): + """ + x: (batch_size, seqlen, nheads, headdim) + cos, sin: (seqlen, rotary_dim / 2) or (batch_size, seqlen, rotary_dim / 2) + """ + ro_dim = cos.shape[-1] * 2 + assert ro_dim <= x.shape[-1] + cos = repeat(cos, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)") + sin = repeat(sin, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 
1 (d 2)") + return torch.cat( + [x[..., :ro_dim] * cos + rotate_half(x[..., :ro_dim], interleaved) * sin, x[..., ro_dim:]], + dim=-1, + ) + + +class ApplyRotaryEmb(torch.autograd.Function): + @staticmethod + def forward( + ctx, + x, + cos, + sin, + interleaved=False, + inplace=False, + seqlen_offsets: Union[int, torch.Tensor] = 0, + cu_seqlens: Optional[torch.Tensor] = None, + max_seqlen: Optional[int] = None, + ): + out = apply_rotary( + x, + cos, + sin, + seqlen_offsets=seqlen_offsets, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + interleaved=interleaved, + inplace=inplace, + ) + if isinstance(seqlen_offsets, int): + ctx.save_for_backward(cos, sin, cu_seqlens) # Can't save int with save_for_backward + ctx.seqlen_offsets = seqlen_offsets + else: + ctx.save_for_backward(cos, sin, cu_seqlens, seqlen_offsets) + ctx.seqlen_offsets = None + ctx.interleaved = interleaved + ctx.inplace = inplace + ctx.max_seqlen = max_seqlen + return out if not inplace else x + + @staticmethod + def backward(ctx, do): + seqlen_offsets = ctx.seqlen_offsets + if seqlen_offsets is None: + cos, sin, cu_seqlens, seqlen_offsets = ctx.saved_tensors + else: + cos, sin, cu_seqlens = ctx.saved_tensors + # TD [2023-09-02]: For some reason Triton (2.0.0.post1) errors with + # "[CUDA]: invalid device context", and cloning makes it work. Idk why. Triton 2.1.0 works. + if not ctx.interleaved and not ctx.inplace: + do = do.clone() + dx = apply_rotary( + do, + cos, + sin, + seqlen_offsets=seqlen_offsets, + cu_seqlens=cu_seqlens, + max_seqlen=ctx.max_seqlen, + interleaved=ctx.interleaved, + inplace=ctx.inplace, + conjugate=True, + ) + return dx, None, None, None, None, None, None, None + + +def apply_rotary_emb( + x, + cos, + sin, + interleaved=False, + inplace=False, + seqlen_offsets: Union[int, torch.Tensor] = 0, + cu_seqlens: Optional[torch.Tensor] = None, + max_seqlen: Optional[int] = None, +): + """ + Arguments: + x: (batch_size, seqlen, nheads, headdim) if cu_seqlens is None + else (total_seqlen, nheads, headdim) + cos, sin: (seqlen_rotary, rotary_dim / 2) + interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead + of 1st half and 2nd half (GPT-NeoX style). + inplace: if True, apply rotary embedding in-place. + seqlen_offsets: (batch_size,) or int. Each sequence in x is shifted by this amount. + Most commonly used in inference when we have KV cache. + cu_seqlens: (batch + 1,) or None + max_seqlen: int + Return: + out: (batch_size, seqlen, nheads, headdim) if cu_seqlens is None + else (total_seqlen, nheads, headdim) + rotary_dim must be <= headdim + Apply rotary embedding to the first rotary_dim of x. 
+ """ + return ApplyRotaryEmb.apply( + x, cos, sin, interleaved, inplace, seqlen_offsets, cu_seqlens, max_seqlen + ) + + +# For backward compatibility +apply_rotary_emb_func = apply_rotary_emb + + +class ApplyRotaryEmbQKV_(torch.autograd.Function): + @staticmethod + def forward( + ctx, + qkv, + cos, + sin, + cos_k=None, + sin_k=None, + interleaved=False, + seqlen_offsets: Union[int, torch.Tensor] = 0, + ): + batch, seqlen, three, nheads, headdim = qkv.shape + assert three == 3 + if cos_k is None and sin_k is None and qkv.is_contiguous(): + # Call 1 kernel instead of 2 kernels + # We need qkv to be contiguous so that when we reshape to combine (3, nheads) + # dimensions, we get the same tensor + # qk = rearrange(qkv[:, :, :2], "b s t h d -> b s (t h) d") + qk = qkv[:, :, :2].reshape(batch, seqlen, -1, headdim) + apply_rotary( + qk, cos, sin, seqlen_offsets=seqlen_offsets, interleaved=interleaved, inplace=True + ) + else: + cos_k = cos if cos_k is None else cos_k + sin_k = sin if sin_k is None else sin_k + q, k = qkv[:, :, 0], qkv[:, :, 1] + apply_rotary(q, cos, sin, seqlen_offsets, interleaved=interleaved, inplace=True) + apply_rotary(k, cos_k, sin_k, seqlen_offsets, interleaved=interleaved, inplace=True) + ctx.save_for_backward(cos, sin, cos_k, sin_k) + if isinstance(seqlen_offsets, int): + ctx.save_for_backward(cos, sin, cos_k, sin_k) + ctx.seqlen_offsets = seqlen_offsets + else: + ctx.save_for_backward(cos, sin, cos_k, sin_k, seqlen_offsets) + ctx.seqlen_offsets = None + ctx.interleaved = interleaved + return qkv + + @staticmethod + def backward(ctx, dqkv): + seqlen_offsets = ctx.seqlen_offsets + if seqlen_offsets is None: + cos, sin, cos_k, sin_k, seqlen_offsets = ctx.saved_tensors + else: + cos, sin, cos_k, sin_k = ctx.saved_tensors + if cos_k is None and sin_k is None and dqkv.is_contiguous(): + # Call 1 kernel instead of 2 kernels + # We need dqkv to be contiguous so that when we reshape to combine (3, nheads) + # dimensions, we get the same tensor + dqk = rearrange(dqkv[:, :, :2], "b s t h d -> b s (t h) d") + apply_rotary( + dqk, + cos, + sin, + seqlen_offsets=seqlen_offsets, + interleaved=ctx.interleaved, + inplace=True, + conjugate=True, + ) + else: + cos_k = cos if cos_k is None else cos_k + sin_k = sin if sin_k is None else sin_k + dq, dk = dqkv[:, :, 0], dqkv[:, :, 1] + apply_rotary( + dq, cos, sin, seqlen_offsets, interleaved=ctx.interleaved, inplace=True, conjugate=True + ) + apply_rotary( + dk, + cos_k, + sin_k, + seqlen_offsets, + interleaved=ctx.interleaved, + inplace=True, + conjugate=True, + ) + return dqkv, None, None, None, None, None, None + + +def apply_rotary_emb_qkv_( + qkv, + cos, + sin, + cos_k=None, + sin_k=None, + interleaved=False, + seqlen_offsets: Union[int, torch.Tensor] = 0, +): + """ + Arguments: + qkv: (batch_size, seqlen, 3, nheads, headdim) + cos, sin: (seqlen, rotary_dim / 2) + cos_k, sin_k: (seqlen, rotary_dim / 2), optional + interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead of + 1st half and 2nd half (GPT-NeoX style). + seqlen_offsets: (batch_size,) or int. Each sequence in Q and K is shifted by this amount. + Most commonly used in inference when we have KV cache. + Return: + qkv: (batch_size, seqlen, 3, nheads, headdim) + rotary_dim must be <= headdim + Apply rotary embedding *inplace* to the first rotary_dim of Q and K. 
+ """ + return ApplyRotaryEmbQKV_.apply(qkv, cos, sin, cos_k, sin_k, interleaved, seqlen_offsets) + + +class ApplyRotaryEmbKV_(torch.autograd.Function): + @staticmethod + def forward(ctx, kv, cos, sin, interleaved=False, seqlen_offsets: Union[int, torch.Tensor] = 0): + batch, seqlen, two, nheads, headdim = kv.shape + assert two == 2 + k = kv[:, :, 0] + apply_rotary( + k, cos, sin, seqlen_offsets=seqlen_offsets, interleaved=interleaved, inplace=True + ) + if isinstance(seqlen_offsets, int): + ctx.save_for_backward(cos, sin) # Can't save int with save_for_backward + ctx.seqlen_offsets = seqlen_offsets + else: + ctx.save_for_backward(cos, sin, seqlen_offsets) + ctx.seqlen_offsets = None + ctx.interleaved = interleaved + return kv + + @staticmethod + def backward(ctx, dkv): + seqlen_offsets = ctx.seqlen_offsets + if seqlen_offsets is None: + cos, sin, seqlen_offsets = ctx.saved_tensors + else: + cos, sin = ctx.saved_tensors + apply_rotary( + dkv[:, :, 0], + cos, + sin, + seqlen_offsets=seqlen_offsets, + interleaved=ctx.interleaved, + inplace=True, + conjugate=True, + ) + return dkv, None, None, None, None + + +apply_rotary_emb_kv_ = ApplyRotaryEmbKV_.apply + + +def apply_rotary_emb_kv_( + kv, + cos, + sin, + interleaved=False, + seqlen_offsets: Union[int, torch.Tensor] = 0, +): + """ + Arguments: + kv: (batch_size, seqlen, 2, nheads, headdim) + cos, sin: (seqlen, rotary_dim / 2) + interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead of + 1st half and 2nd half (GPT-NeoX style). + seqlen_offsets: (batch_size,) or int. Each sequence in Q and K is shifted by this amount. + Most commonly used in inference when we have KV cache. + Return: + kv: (batch_size, seqlen, 2, nheads, headdim) + rotary_dim must be <= headdim + Apply rotary embedding *inplace* to the first rotary_dim of K. + """ + return ApplyRotaryEmbKV_.apply(kv, cos, sin, interleaved, seqlen_offsets) + + +class RotaryEmbedding(torch.nn.Module): + """ + The rotary position embeddings from RoFormer_ (Su et. al). + A crucial insight from the method is that the query and keys are + transformed by rotation matrices which depend on the relative positions. + + Other implementations are available in the Rotary Transformer repo_ and in + GPT-NeoX_, GPT-NeoX was an inspiration + + .. _RoFormer: https://arxiv.org/abs/2104.09864 + .. _repo: https://github.com/ZhuiyiTechnology/roformer + .. _GPT-NeoX: https://github.com/EleutherAI/gpt-neox + + If scale_base is not None, this implements XPos (Sun et al., https://arxiv.org/abs/2212.10554). + A recommended value for scale_base is 512: https://github.com/HazyResearch/flash-attention/issues/96 + Reference: https://github.com/sunyt32/torchscale/blob/main/torchscale/component/xpos_relative_position.py + """ + + def __init__( + self, + dim: int, + base=10000.0, + interleaved=False, + scale_base=None, + pos_idx_in_fp32=True, + device=None, + ): + """ + interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead + of 1st half and 2nd half (GPT-NeoX style). + pos_idx_in_fp32: if True, the position indices [0.0, ..., seqlen - 1] are in fp32, + otherwise they might be in lower precision. + This option was added because previously (before 2023-07-02), when we construct + the position indices, we use the dtype of self.inv_freq. In most cases this would + be fp32, but if the model is trained in pure bf16 (not mixed precision), then + self.inv_freq would be bf16, and the position indices are also in bf16. + Because of the limited precision of bf16 (e.g. 
1995.0 is rounded to 2000.0), the + embeddings for some positions will coincide. + To maintain compatibility with models previously trained in pure bf16, + we add this option. + """ + super().__init__() + self.dim = dim + self.base = float(base) + self.pos_idx_in_fp32 = pos_idx_in_fp32 + # Generate and save the inverse frequency buffer (non trainable) + inv_freq = self._compute_inv_freq(device) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.interleaved = interleaved + self.scale_base = scale_base + scale = ( + (torch.arange(0, dim, 2, device=device, dtype=torch.float32) + 0.4 * dim) / (1.4 * dim) + if scale_base is not None + else None + ) + self.register_buffer("scale", scale, persistent=False) + + self._seq_len_cached = 0 + self._cos_cached = None + self._sin_cached = None + self._cos_k_cached = None + self._sin_k_cached = None + + def _compute_inv_freq(self, device=None): + return 1.0 / ( + self.base + ** (torch.arange(0, self.dim, 2, device=device, dtype=torch.float32) / self.dim) + ) + + def _update_cos_sin_cache(self, seqlen, device=None, dtype=None): + # Reset the tables if the sequence length has changed, + # if we're on a new device (possibly due to tracing for instance), + # or if we're switching from inference mode to training + if ( + seqlen > self._seq_len_cached + or self._cos_cached is None + or self._cos_cached.device != device + or self._cos_cached.dtype != dtype + or (self.training and self._cos_cached.is_inference()) + ): + self._seq_len_cached = seqlen + # We want fp32 here, not self.inv_freq.dtype, since the model could be loaded in bf16 + # And the output of arange can be quite large, so bf16 would lose a lot of precision. + # However, for compatibility reason, we add an option to use the dtype of self.inv_freq. + if self.pos_idx_in_fp32: + t = torch.arange(seqlen, device=device, dtype=torch.float32) + # We want fp32 here as well since inv_freq will be multiplied with t, and the output + # will be large. Having it in bf16 will lose a lot of precision and cause the + # cos & sin output to change significantly. 
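+                # (e.g. at position 8192, bf16 can only represent multiples of 64,
+                # so long runs of position indices would collapse to one value)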
+ # We want to recompute self.inv_freq if it was not loaded in fp32 + if self.inv_freq.dtype != torch.float32: + inv_freq = self._compute_inv_freq(device=device) + else: + inv_freq = self.inv_freq + else: + t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype) + inv_freq = self.inv_freq + # Don't do einsum, it converts fp32 to fp16 under AMP + # freqs = torch.einsum("i,j->ij", t, self.inv_freq) + freqs = torch.outer(t, inv_freq) + if self.scale is None: + self._cos_cached = torch.cos(freqs).to(dtype) + self._sin_cached = torch.sin(freqs).to(dtype) + else: + power = ( + torch.arange(seqlen, dtype=self.scale.dtype, device=self.scale.device) + - seqlen // 2 + ) / self.scale_base + scale = self.scale.to(device=power.device) ** rearrange(power, "s -> s 1") + # We want the multiplication by scale to happen in fp32 + self._cos_cached = (torch.cos(freqs) * scale).to(dtype) + self._sin_cached = (torch.sin(freqs) * scale).to(dtype) + self._cos_k_cached = (torch.cos(freqs) / scale).to(dtype) + self._sin_k_cached = (torch.sin(freqs) / scale).to(dtype) + + def forward( + self, + qkv: torch.Tensor, + kv: Optional[torch.Tensor] = None, + seqlen_offset: Union[int, torch.Tensor] = 0, + max_seqlen: Optional[int] = None, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + """ + qkv: (batch, seqlen, 3, nheads, headdim) if kv is none, + else it's just q of shape (batch, seqlen, nheads, headdim) + kv: (batch, seqlen, 2, nheads, headdim) + seqlen_offset: (batch_size,) or int. Each sequence in x is shifted by this amount. + Most commonly used in inference when we have KV cache. + If it's a tensor of shape (batch_size,), then to update the cos / sin cache, one + should pass in max_seqlen, which will update the cos / sin cache up to that length. + Apply rotary embedding *inplace* to qkv and / or kv. 
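+        Example (assumed shapes, fused self-attention case):
+            rotary_emb = RotaryEmbedding(dim=headdim)
+            qkv = rotary_emb(qkv)  # Q and K inside qkv are rotated in-place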
+ """ + seqlen = qkv.shape[1] + if max_seqlen is not None: + self._update_cos_sin_cache(max_seqlen, device=qkv.device, dtype=qkv.dtype) + elif isinstance(seqlen_offset, int): + self._update_cos_sin_cache(seqlen + seqlen_offset, device=qkv.device, dtype=qkv.dtype) + if kv is None: + if self.scale is None: + return apply_rotary_emb_qkv_( + qkv, + self._cos_cached, + self._sin_cached, + interleaved=self.interleaved, + seqlen_offsets=seqlen_offset, + ) + else: + return apply_rotary_emb_qkv_( + qkv, + self._cos_cached, + self._sin_cached, + self._cos_k_cached, + self._sin_k_cached, + interleaved=self.interleaved, + seqlen_offsets=seqlen_offset, + ) + else: + q = qkv + q = apply_rotary_emb_func( + q, + self._cos_cached, + self._sin_cached, + interleaved=self.interleaved, + inplace=True, + seqlen_offsets=seqlen_offset, + ) + if self.scale is None: + kv = apply_rotary_emb_kv_( + kv, + self._cos_cached, + self._sin_cached, + interleaved=self.interleaved, + seqlen_offsets=seqlen_offset, + ) + else: + kv = apply_rotary_emb_kv_( + kv, + self._cos_k_cached, + self._sin_k_cached, + interleaved=self.interleaved, + seqlen_offsets=seqlen_offset, + ) + return q, kv diff --git a/.venv/lib/python3.11/site-packages/xformers/_flash_attn/losses/__init__.py b/.venv/lib/python3.11/site-packages/xformers/_flash_attn/losses/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/.venv/lib/python3.11/site-packages/xformers/_flash_attn/losses/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/xformers/_flash_attn/losses/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4fc0605e89f54f4bbaf4eb0d04eb56c558706fd2 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/xformers/_flash_attn/losses/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/xformers/_flash_attn/losses/__pycache__/cross_entropy.cpython-311.pyc b/.venv/lib/python3.11/site-packages/xformers/_flash_attn/losses/__pycache__/cross_entropy.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..10a096afcbc76dd26bb570c9480ecfa45ea847a7 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/xformers/_flash_attn/losses/__pycache__/cross_entropy.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/xformers/_flash_attn/losses/cross_entropy.py b/.venv/lib/python3.11/site-packages/xformers/_flash_attn/losses/cross_entropy.py new file mode 100644 index 0000000000000000000000000000000000000000..d4dcf66758c353b1094218ac78f6f99db2b32224 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/xformers/_flash_attn/losses/cross_entropy.py @@ -0,0 +1,85 @@ +# Copyright (c) 2024, Tri Dao. + +import torch +import torch.nn as nn + +from flash_attn.ops.triton.cross_entropy import cross_entropy_loss + + +class CrossEntropyLoss(nn.Module): + def __init__( + self, + ignore_index=-100, + reduction="mean", + label_smoothing=0.0, + logit_scale=1.0, + lse_square_scale=0.0, + inplace_backward=False, + process_group=None, + return_z_loss=False, + ): + """ + Arguments: + ignore_index: int. If labels == ignore_index, the loss is set to 0.0. + label_smoothing: float + lse_square_scale: float. If > 0, we add lse_square_scale * lse(logits) ^ 2 to the loss. + This is also referred to as "z-loss". + inplace_backward: bool. If True, we do the backward pass in-place by modifying the logits. + This saves memory. 
+            process_group: if not None, we're doing Tensor Parallel: each process is responsible for
+                one part of the vocab. The loss will be aggregated across processes.
+            return_z_loss: bool. If True, we return the component of the loss contributed by
+                the lse_square_scale value. This value is only for logging and does not support
+                backprop.
+        """
+        super().__init__()
+        if reduction not in ["mean", "none", "sum"]:
+            raise NotImplementedError("Only support reduction = 'mean' or 'none' or 'sum'")
+        self.ignore_index = ignore_index
+        self.reduction = reduction
+        self.label_smoothing = label_smoothing
+        self.logit_scale = logit_scale
+        self.lse_square_scale = lse_square_scale
+        self.inplace_backward = inplace_backward
+        self.process_group = process_group
+        self.return_z_loss = return_z_loss
+
+    def forward(self, input, target, precomputed_lse=None):
+        """
+        Arguments:
+            input: (batch, vocab_size)
+            target: (batch,)
+        Returns:
+            losses: (batch,) if reduction is 'none', else (1,), dtype float
+            z_loss: (batch,) if reduction is 'none', else (1,), dtype float (if self.return_z_loss)
+        """
+        assert input.is_cuda and target.is_cuda, "Only support CUDA tensors"
+        loss, z_loss = cross_entropy_loss(
+            input,
+            target,
+            precomputed_lse=precomputed_lse,
+            label_smoothing=self.label_smoothing,
+            logit_scale=self.logit_scale,
+            lse_square_scale=self.lse_square_scale,
+            ignore_index=self.ignore_index,
+            inplace_backward=self.inplace_backward,
+            process_group=self.process_group,
+        )
+        if self.reduction == "mean":
+            loss = loss.sum() / (target != self.ignore_index).sum()
+        elif self.reduction == "sum":
+            loss = loss.sum()
+        else:
+            loss = loss
+
+        if not self.return_z_loss:
+            return loss
+
+        if self.reduction == "mean":
+            z_loss = z_loss.sum() / (target != self.ignore_index).sum()
+        elif self.reduction == "sum":
+            z_loss = z_loss.sum()
+        else:
+            z_loss = z_loss
+
+        return loss, z_loss
diff --git a/.venv/lib/python3.11/site-packages/xformers/_flash_attn/models/__init__.py b/.venv/lib/python3.11/site-packages/xformers/_flash_attn/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/.venv/lib/python3.11/site-packages/xformers/_flash_attn/models/baichuan.py b/.venv/lib/python3.11/site-packages/xformers/_flash_attn/models/baichuan.py
new file mode 100644
index 0000000000000000000000000000000000000000..97d030782187afdfa22b9ad0a9a264b9f6c0a95e
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/xformers/_flash_attn/models/baichuan.py
@@ -0,0 +1,151 @@
+# Copyright (c) 2023, GGGGGGXY, Tri Dao.
+
+import math
+import json
+import re
+from pathlib import Path
+
+from collections import OrderedDict
+
+import torch
+import torch.nn.functional as F
+
+from einops import rearrange
+from transformers import GPT2Config, AutoConfig, PretrainedConfig
+
+
+def remap_state_dict_hf_baichuan(state_dict, config):
+    def key_mapping_layers(key):
+        return re.sub(r"^model.", "transformer.", key)
+
+    state_dict = OrderedDict((key_mapping_layers(k), v) for k, v in state_dict.items())
+
+    # Word embedding
+    def key_mapping_emb(key):
+        return re.sub(
+            r"^transformer.embed_tokens.",
+            "transformer.embeddings.word_embeddings.",
+            key,
+        )
+
+    state_dict = OrderedDict((key_mapping_emb(k), v) for k, v in state_dict.items())
+    word_embeddings = state_dict.pop("transformer.embeddings.word_embeddings.weight")
+    # It's possible that vocab_size is padded to be a multiple of 8, for example.
+    pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
+    vocab_size = (
+        math.ceil(word_embeddings.shape[0] / pad_vocab_size_multiple)
+        * pad_vocab_size_multiple
+    )
+    state_dict["transformer.embeddings.word_embeddings.weight"] = F.pad(
+        word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0])
+    )
+    if getattr(config, "tie_word_embeddings"):
+        state_dict["lm_head.weight"] = state_dict[
+            "transformer.embeddings.word_embeddings.weight"
+        ]
+    else:
+        output_embeddings = state_dict.pop("lm_head.weight")
+        # Need to recompute vocab_size since Baichuan shards the word embeddings and output embeddings
+        # differently.
+        vocab_size = (
+            math.ceil(output_embeddings.shape[0] / pad_vocab_size_multiple)
+            * pad_vocab_size_multiple
+        )
+        # It's possible that vocab_size is padded to be a multiple of 8, for example.
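+        # (F.pad with pad=(0, 0, 0, n) appends n zero rows to the 2D weight,
+        # leaving the embedding dimension untouched)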
+ state_dict["lm_head.weight"] = F.pad( + output_embeddings, (0, 0, 0, vocab_size - output_embeddings.shape[0]) + ) + + # LayerNorm + def key_mapping_ln(key): + key = re.sub(r"^transformer.norm.", r"transformer.ln_f.", key) + key = re.sub( + r"^transformer.layers.(\d+).input_layernorm.", + r"transformer.layers.\1.norm1.", + key, + ) + key = re.sub( + r"^transformer.layers.(\d+).post_attention_layernorm.", + r"transformer.layers.\1.norm2.", + key, + ) + return key + + state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items()) + + # MLP + for l in range(config.n_layer): + w1 = state_dict.pop(f"transformer.layers.{l}.mlp.gate_proj.weight") + w3 = state_dict.pop(f"transformer.layers.{l}.mlp.up_proj.weight") + # Our ordering is different + state_dict[f"transformer.layers.{l}.mlp.fc1.weight"] = torch.cat( + [w3, w1], dim=0 + ) + + def key_mapping_mlp(key): + return re.sub( + r"^transformer.layers.(\d+).mlp.down_proj.", + r"transformer.layers.\1.mlp.fc2.", + key, + ) + + state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items()) + + # Attention + def key_mapping_attn(key): + key = re.sub( + r"^transformer.layers.(\d+).self_attn.W_pack.", + r"transformer.layers.\1.mixer.Wqkv.", + key, + ) + key = re.sub( + r"^transformer.layers.(\d+).self_attn.o_proj.", + r"transformer.layers.\1.mixer.out_proj.", + key, + ) + return key + + state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items()) + for l in range(config.n_layer): + # pop rotary_emb.inv_freq from state dict + state_dict.pop(f"transformer.layers.{l}.self_attn.rotary_emb.inv_freq", None) + return state_dict + + +def baichuan_config_to_gpt2_config(baichuan_config: PretrainedConfig) -> GPT2Config: + # HACK: the config doesn't have say whether it's rotary or alibi. + # So we have to infer from the hidden size (7B -> rotary, 13B -> alibi). + # HACK: the config doesn't have say whether it uses norm head. + # So we have to infer from the vocab size + # (v1, vocab size 64k, no norm head; v2, vocab size 128k, norm head). 
+ use_rotary = baichuan_config.hidden_size < 5000 + return GPT2Config( + vocab_size=baichuan_config.vocab_size, + n_positions=0, # No absolute position embedding + n_embd=baichuan_config.hidden_size, + n_layer=baichuan_config.num_hidden_layers, + n_head=baichuan_config.num_attention_heads, + n_inner=baichuan_config.intermediate_size, + activation_function="swiglu", # Hardcode since HF calls it 'silu' + # baichuan doesn't have dropout, idk if it's because they only release the inference code + resid_pdrop=0.0, + embd_pdrop=0.0, + attn_pdrop=0.0, + layer_norm_epsilon=baichuan_config.rms_norm_eps, + initializer_range=baichuan_config.initializer_range, + bos_token_id=baichuan_config.bos_token_id, + eos_token_id=baichuan_config.eos_token_id, + # These are new arguments not in the original GPT2Config + pad_token_id=baichuan_config.pad_token_id, # Idk if this does anything + rms_norm=True, + rotary_emb_fraction=1.0 if use_rotary else 0.0, + rotary_emb_interleaved=False, + use_alibi=not use_rotary, + use_flash_attn=not use_rotary, # Alibi code path requires flash_attn + tie_word_embeddings=False, + norm_head=baichuan_config.vocab_size > 70000, + qkv_proj_bias=False, + out_proj_bias=False, + mlp_fc1_bias=False, + mlp_fc2_bias=False, + ) diff --git a/.venv/lib/python3.11/site-packages/xformers/_flash_attn/models/bert.py b/.venv/lib/python3.11/site-packages/xformers/_flash_attn/models/bert.py new file mode 100644 index 0000000000000000000000000000000000000000..33d6935202a1b99393ef34d56a6b4fa0e188ab57 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/xformers/_flash_attn/models/bert.py @@ -0,0 +1,764 @@ +# Copyright (c) 2022, Tri Dao. +# This BERT implementation is based on our MLPerf 2.0 and MLPerf 2.1 BERT implementation. +# https://github.com/mlcommons/training_results_v2.0/blob/main/HazyResearch/benchmarks/bert/implementations/pytorch/modeling.py +# https://github.com/mlcommons/training_results_v2.1/blob/main/Azure-HazyResearch/benchmarks/bert/implementations/ND96amsr_A100_v4/modeling.py + +# Inspired by https://github.com/huggingface/transformers/blob/main/src/transformers/models/bert/modeling_bert.py + +import logging +import re +from collections import OrderedDict +from collections.abc import Sequence +from functools import partial +from typing import Any, Mapping + +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange +from transformers import BertConfig, PretrainedConfig +from transformers.models.bert.modeling_bert import ( + BaseModelOutputWithPoolingAndCrossAttentions, + BertForPreTrainingOutput, +) + +from flash_attn.bert_padding import ( + index_first_axis, + index_first_axis_residual, + pad_input, + unpad_input, +) +from flash_attn.modules.block import Block +from flash_attn.modules.embedding import BertEmbeddings +from flash_attn.modules.mha import MHA +from flash_attn.modules.mlp import FusedMLP, Mlp +from flash_attn.utils.pretrained import state_dict_from_pretrained + +try: + from flash_attn.ops.fused_dense import FusedDense +except ImportError: + FusedDense = None + +try: + from flash_attn.ops.triton.layer_norm import layer_norm_fn +except ImportError: + layer_norm_fn = None + + +try: + from flash_attn.losses.cross_entropy import CrossEntropyLoss +except ImportError: + CrossEntropyLoss = None + + +logger = logging.getLogger(__name__) + + +def create_mixer_cls(config, cross_attn=False, return_residual=False): + use_flash_attn = getattr(config, "use_flash_attn", False) + fused_bias_fc = getattr(config, "fused_bias_fc", False) + 
rotary_kwargs = {} + if config.position_embedding_type == "rotary": + rotary_kwargs["rotary_emb_dim"] = getattr(config, "rotary_emb_dim", config.hidden_size) + rotary_kwargs["rotary_emb_base"] = getattr(config, "rotary_emb_base", 10000.0) + rotary_kwargs["rotary_emb_scale_base"] = getattr(config, "rotary_emb_scale_base", None) + rotary_kwargs["rotary_emb_interleaved"] = getattr(config, "rotary_emb_interleaved", False) + mixer_cls = partial( + MHA, + num_heads=config.num_attention_heads, + cross_attn=cross_attn, + dropout=config.attention_probs_dropout_prob, + causal=False, + fused_bias_fc=fused_bias_fc, + use_flash_attn=use_flash_attn, + return_residual=return_residual, + **rotary_kwargs, + ) + return mixer_cls + + +def create_mlp_cls(config, layer_idx=None, return_residual=False): + inner_dim = config.intermediate_size + fused_mlp = getattr(config, "fused_mlp", False) + if fused_mlp: + assert config.hidden_act in ["gelu_new", "gelu_fast", "gelu_pytorch_tanh"], ( + "fused_mlp only " "supports approximate gelu" + ) + if not fused_mlp: + approximate = ( + "tanh" + if config.hidden_act in ["gelu_new", "gelu_fast", "gelu_pytorch_tanh"] + else "none" + ) + mlp_cls = partial( + Mlp, + hidden_features=inner_dim, + activation=partial(F.gelu, approximate=approximate), + return_residual=return_residual, + ) + else: + if FusedMLP is None: + raise ImportError("fused_dense is not installed") + mlp_checkpoint_lvl = getattr(config, "mlp_checkpoint_lvl", 0) + # mlp_checkpoint_lvl could be a list, which contains the checkpoint_lvl for each layer + if isinstance(mlp_checkpoint_lvl, Sequence): + assert layer_idx is not None + mlp_checkpoint_lvl = mlp_checkpoint_lvl[layer_idx] + mlp_cls = partial( + FusedMLP, + hidden_features=inner_dim, + checkpoint_lvl=mlp_checkpoint_lvl, + return_residual=return_residual, + ) + return mlp_cls + + +def create_block(config, layer_idx=None): + last_layer_subset = getattr(config, "last_layer_subset", False) + cross_attn = last_layer_subset and layer_idx == config.num_hidden_layers - 1 + # TD [2022-12-19]: For cross attention (last layer), we actually want to return the + # residual x_kv, not residual x. But it's annoying to change the API (and it only affects + # one layer) so we just choose not to return residual in this case. 
+ return_residual = not cross_attn + mixer_cls = create_mixer_cls(config, cross_attn, return_residual=return_residual) + mlp_cls = create_mlp_cls(config, layer_idx, return_residual=return_residual) + norm_cls = partial(nn.LayerNorm, eps=config.layer_norm_eps) + block = Block( + config.hidden_size, + mixer_cls, + mlp_cls, + norm_cls=norm_cls, + prenorm=False, + resid_dropout1=config.hidden_dropout_prob, + resid_dropout2=config.hidden_dropout_prob, + fused_dropout_add_ln=getattr(config, "fused_dropout_add_ln", False), + return_residual=return_residual, + ) + return block + + +# https://github.com/huggingface/transformers/blob/7032e0203262ebb2ebf55da8d2e01f873973e835/src/transformers/models/bert/modeling_bert.py#L748 +def _init_weights(module, initializer_range=0.02): + if isinstance(module, nn.Linear): + nn.init.normal_(module.weight, std=initializer_range) + if module.bias is not None: + nn.init.zeros_(module.bias) + elif isinstance(module, nn.Embedding): + nn.init.normal_(module.weight, std=initializer_range) + if module.padding_idx is not None: + nn.init.zeros_(module.weight[module.padding_idx]) + + +class BertEncoder(nn.Module): + def __init__(self, config: BertConfig): + super().__init__() + self.use_flash_attn = getattr(config, "use_flash_attn", False) + self.layers = nn.ModuleList( + [create_block(config, layer_idx=i) for i in range(config.num_hidden_layers)] + ) + + def forward(self, hidden_states, key_padding_mask=None, subset_mask=None): + """If subset_mask is not None, we only want output for the subset of the sequence. + This means that we only compute the last layer output for these tokens. + subset_mask: (batch, seqlen), dtype=torch.bool + """ + if key_padding_mask is None or not self.use_flash_attn: + mixer_kwargs = ( + {"key_padding_mask": key_padding_mask} if key_padding_mask is not None else None + ) + for layer in self.layers: + hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs) + if subset_mask is not None: + hidden_states = hidden_states[subset_mask] + else: + batch, seqlen = hidden_states.shape[:2] + hidden_states, indices, cu_seqlens, max_seqlen_in_batch = unpad_input( + hidden_states, key_padding_mask + ) + mixer_kwargs = {"cu_seqlens": cu_seqlens, "max_seqlen": max_seqlen_in_batch} + if subset_mask is None: + for layer in self.layers: + hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs) + hidden_states = pad_input(hidden_states, indices, batch, seqlen) + else: + for layer in self.layers[:-1]: + hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs) + if key_padding_mask is not None: + subset_idx = torch.nonzero( + subset_mask[key_padding_mask], as_tuple=False + ).flatten() + subset_seqlens = (subset_mask & key_padding_mask).sum(dim=-1, dtype=torch.int32) + subset_cu_seqlens = F.pad( + torch.cumsum(subset_seqlens, dim=0, dtype=torch.torch.int32), (1, 0) + ) + else: + subset_idx = torch.nonzero(subset_mask, as_tuple=False).flatten() + subset_seqlens = subset_mask.sum(dim=-1, dtype=torch.int32) + subset_cu_seqlens = F.pad( + torch.cumsum(subset_seqlens, dim=0, dtype=torch.torch.int32), (1, 0) + ) + hidden_states_subset, hidden_states = index_first_axis_residual( + hidden_states, subset_idx + ) + # It's ok to set max_seqlen_q to be much larger + mixer_kwargs = { + "x_kv": hidden_states, + "cu_seqlens": subset_cu_seqlens, + "max_seqlen": max_seqlen_in_batch, + "cu_seqlens_k": cu_seqlens, + "max_seqlen_k": max_seqlen_in_batch, + } + hidden_states = self.layers[-1](hidden_states_subset, mixer_kwargs=mixer_kwargs) + return hidden_states + + 
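+# (Note on BertEncoder.forward above -- shape sketch of the unpad/pad round trip:
+#   hidden_states (batch, seqlen, d), key_padding_mask (batch, seqlen) bool
+#     -> unpad_input: packed (total_nnz, d) plus indices, cu_seqlens, max_seqlen_in_batch
+#     -> the layers run on the packed tensor with varlen kwargs
+#     -> pad_input(hidden_states, indices, batch, seqlen): back to (batch, seqlen, d))
+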
+class BertPooler(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        fused_bias_fc = getattr(config, "fused_bias_fc", False)
+        if fused_bias_fc and FusedDense is None:
+            raise ImportError("fused_dense is not installed")
+        linear_cls = nn.Linear if not fused_bias_fc else FusedDense
+        self.dense = linear_cls(config.hidden_size, config.hidden_size)
+        self.activation = nn.Tanh()
+
+    def forward(self, hidden_states, pool=True):
+        # We "pool" the model by simply taking the hidden state corresponding
+        # to the first token.
+        first_token_tensor = hidden_states[:, 0] if pool else hidden_states
+        pooled_output = self.dense(first_token_tensor)
+        pooled_output = self.activation(pooled_output)
+        return pooled_output
+
+
+class BertPredictionHeadTransform(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        fused_bias_fc = getattr(config, "fused_bias_fc", False)
+        if fused_bias_fc and FusedDense is None:
+            raise ImportError("fused_dense is not installed")
+        self.fused_dropout_add_ln = getattr(config, "fused_dropout_add_ln", False)
+        if self.fused_dropout_add_ln and layer_norm_fn is None:
+            raise ImportError("Triton is not installed")
+        linear_cls = nn.Linear if not fused_bias_fc else FusedDense
+        self.dense = linear_cls(config.hidden_size, config.hidden_size)
+        approximate = (
+            "tanh"
+            if config.hidden_act in ["gelu_new", "gelu_fast", "gelu_pytorch_tanh"]
+            else "none"
+        )
+        self.transform_act_fn = nn.GELU(approximate=approximate)
+        self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        hidden_states = self.dense(hidden_states)
+        hidden_states = self.transform_act_fn(hidden_states)
+        if not self.fused_dropout_add_ln:
+            hidden_states = self.layer_norm(hidden_states)
+        else:
+            hidden_states = layer_norm_fn(
+                hidden_states, self.layer_norm.weight, self.layer_norm.bias, eps=self.layer_norm.eps
+            )
+        return hidden_states
+
+
+class BertLMPredictionHead(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        fused_bias_fc = getattr(config, "fused_bias_fc", False)
+        if fused_bias_fc and FusedDense is None:
+            raise ImportError("fused_dense is not installed")
+        linear_cls = nn.Linear if not fused_bias_fc else FusedDense
+
+        self.transform = BertPredictionHeadTransform(config)
+
+        # The output weights are the same as the input embeddings, but there is
+        # an output-only bias for each token.
+        self.decoder = linear_cls(config.hidden_size, config.vocab_size, bias=True)
+
+    def forward(self, hidden_states):
+        hidden_states = self.transform(hidden_states)
+        hidden_states = self.decoder(hidden_states)
+        return hidden_states
+
+
+class BertPreTrainingHeads(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.predictions = BertLMPredictionHead(config)
+        self.seq_relationship = nn.Linear(config.hidden_size, 2)
+
+    def forward(self, sequence_output, pooled_output):
+        prediction_scores = self.predictions(sequence_output)
+        seq_relationship_score = self.seq_relationship(pooled_output)
+        return prediction_scores, seq_relationship_score
+
+
+class BertPreTrainedModel(nn.Module):
+    """An abstract class to handle weights initialization and
+    a simple interface for downloading and loading pretrained models.
+    """
+
+    def __init__(self, config, *inputs, **kwargs):
+        super().__init__()
+        if not isinstance(config, BertConfig):
+            raise ValueError(
+                "Parameter config in `{}(config)` should be an instance of class `BertConfig`. "
+                "To create a model from a Google pretrained model use "
+                "`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format(
+                    self.__class__.__name__, self.__class__.__name__
+                )
+            )
+        self.config = config
+
+    @classmethod
+    def from_pretrained(cls, model_name, config, *inputs, **kwargs):
+        """
+        Instantiate a BertPreTrainedModel from a pre-trained model file or a pytorch state dict.
+        Download and cache the pre-trained model file if needed.
+
+        Params:
+            pretrained_model_name_or_path: either:
+                - a path or url to a pretrained model archive containing:
+                    . `bert_config.json` a configuration file for the model
+                    . `pytorch_model.bin` a PyTorch dump of a BertForPretraining instance
+                - a path or url to a pretrained model archive containing:
+                    . `bert_config.json` a configuration file for the model
+                    . `model.chkpt` a TensorFlow checkpoint
+            *inputs, **kwargs: additional input for the specific Bert class
+                (ex: num_labels for BertForSequenceClassification)
+        """
+        # Instantiate model.
+        model = cls(config, *inputs, **kwargs)
+        load_return = model.load_state_dict(
+            remap_state_dict(state_dict_from_pretrained(model_name), config), strict=False
+        )
+        logger.info(load_return)
+        return model
+
+
+class BertModel(BertPreTrainedModel):
+    def __init__(self, config: BertConfig, add_pooling_layer=True):
+        super().__init__(config)
+        self.pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
+        if config.vocab_size % self.pad_vocab_size_multiple != 0:
+            config.vocab_size += self.pad_vocab_size_multiple - (
+                config.vocab_size % self.pad_vocab_size_multiple
+            )
+        self.fused_dropout_add_ln = getattr(config, "fused_dropout_add_ln", False)
+        if self.fused_dropout_add_ln and layer_norm_fn is None:
+            raise ImportError("Triton is not installed")
+        assert config.hidden_act in ["gelu", "gelu_new", "gelu_fast", "gelu_pytorch_tanh"]
+
+        self.embeddings = BertEmbeddings(
+            config.hidden_size,
+            config.vocab_size,
+            config.max_position_embeddings,
+            config.type_vocab_size,
+            padding_idx=config.pad_token_id,
+        )
+        self.emb_drop = nn.Dropout(config.hidden_dropout_prob)
+        self.emb_ln = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+        self.encoder = BertEncoder(config)
+        self.pooler = BertPooler(config) if add_pooling_layer else None
+
+        self.apply(partial(_init_weights, initializer_range=config.initializer_range))
+
+    def forward(
+        self,
+        input_ids,
+        position_ids=None,
+        token_type_ids=None,
+        attention_mask=None,
+        masked_tokens_mask=None,
+    ):
+        """If masked_tokens_mask is not None (i.e. last_layer_subset == True in BertForPreTraining),
+        we only want the output for the masked tokens. This means that we only compute the last
+        layer output for these tokens.
+        masked_tokens_mask: (batch, seqlen), dtype=torch.bool
+        """
+        hidden_states = self.embeddings(
+            input_ids, position_ids=position_ids, token_type_ids=token_type_ids
+        )
+        # TD [2022-12-18]: Don't need to force residual in fp32
+        # BERT puts embedding LayerNorm before embedding dropout.
+ if not self.fused_dropout_add_ln: + hidden_states = self.emb_ln(hidden_states) + else: + hidden_states = layer_norm_fn( + hidden_states, self.emb_ln.weight, self.emb_ln.bias, eps=self.emb_ln.eps + ) + hidden_states = self.emb_drop(hidden_states) + + if masked_tokens_mask is not None: + batch_size, seqlen = input_ids.shape[:2] + # We also need the first column for the CLS token + first_col_mask = torch.zeros( + batch_size, seqlen, dtype=torch.bool, device=input_ids.device + ) + first_col_mask[:, 0] = True + subset_mask = masked_tokens_mask | first_col_mask + else: + subset_mask = None + + sequence_output = self.encoder( + hidden_states, key_padding_mask=attention_mask, subset_mask=subset_mask + ) + + if masked_tokens_mask is None: + pooled_output = self.pooler(sequence_output) if self.pooler is not None else None + else: + # TD [2022-03-01]: the indexing here is very tricky. + if attention_mask is not None: + subset_idx = subset_mask[attention_mask] + pool_input = sequence_output[first_col_mask[attention_mask][subset_idx]] + sequence_output = sequence_output[masked_tokens_mask[attention_mask][subset_idx]] + else: + pool_input = sequence_output[first_col_mask[subset_mask]] + sequence_output = sequence_output[masked_tokens_mask[subset_mask]] + pooled_output = self.pooler(pool_input, pool=False) if self.pooler is not None else None + + return BaseModelOutputWithPoolingAndCrossAttentions( + last_hidden_state=sequence_output, + pooler_output=pooled_output, + ) + + +class BertForPreTraining(BertPreTrainedModel): + def __init__(self, config: BertConfig): + super().__init__(config) + # If dense_seq_output, we only need to pass the hidden states for the masked out tokens + # (around 15%) to the classifier heads. + self.dense_seq_output = getattr(config, "dense_seq_output", False) + # If last_layer_subset, we only need the compute the last layer for a subset of tokens + # (e.g., the tokens we need to compute the masked LM loss and the next-sentence prediction). + self.last_layer_subset = getattr(config, "last_layer_subset", False) + if self.last_layer_subset: + assert self.dense_seq_output, "last_layer_subset requires dense_seq_output" + use_xentropy = getattr(config, "use_xentropy", False) + if use_xentropy and CrossEntropyLoss is None: + raise ImportError("xentropy_cuda is not installed") + loss_cls = ( + nn.CrossEntropyLoss + if not use_xentropy + else partial(CrossEntropyLoss, inplace_backward=True) + ) + + self.bert = BertModel(config) + self.cls = BertPreTrainingHeads(config) + self.mlm_loss = loss_cls(ignore_index=0) + self.nsp_loss = loss_cls(ignore_index=-1) + + # Initialize weights and apply final processing + self.apply(partial(_init_weights, initializer_range=config.initializer_range)) + self.tie_weights() + + def tie_weights(self): + self.cls.predictions.decoder.weight = self.bert.embeddings.word_embeddings.weight + + def forward( + self, + input_ids, + position_ids=None, + token_type_ids=None, + attention_mask=None, + labels=None, + next_sentence_label=None, + ): + """ + If labels are provided, they must be 0 for masked out tokens (as specified in the attention + mask). + Outputs: + if `labels` and `next_sentence_label` are not `None`: + Outputs the total_loss which is the sum of the masked language modeling loss and the next + sentence classification loss. 
+ if `labels` or `next_sentence_label` is `None`: + Outputs a BertForPreTrainingOutput with `loss=None`, comprising + - the masked language modeling logits of shape [batch_size, sequence_length, vocab_size], and + - the next sentence classification logits of shape [batch_size, 2]. + + """ + masked_tokens_mask = labels > 0 if (self.last_layer_subset and labels is not None) else None + outputs = self.bert( + input_ids, + position_ids=position_ids, + token_type_ids=token_type_ids, + attention_mask=attention_mask.bool() if attention_mask is not None else None, + masked_tokens_mask=masked_tokens_mask, + ) + sequence_output, pooled_output = outputs.last_hidden_state, outputs.pooler_output + if self.dense_seq_output and labels is not None: + masked_token_idx = torch.nonzero(labels.flatten() > 0, as_tuple=False).flatten() + if not self.last_layer_subset: + sequence_output = index_first_axis( + rearrange(sequence_output, "b s d -> (b s) d"), masked_token_idx + ) + prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output) + + total_loss = None + if labels is not None and next_sentence_label is not None: + if ( + self.dense_seq_output and labels is not None + ): # prediction_scores are already flattened + masked_lm_loss = self.mlm_loss( + prediction_scores, labels.flatten()[masked_token_idx] + ) + else: + masked_lm_loss = self.mlm_loss( + rearrange(prediction_scores, "... v -> (...) v"), + rearrange(labels, "... -> (...)"), + ) + next_sentence_loss = self.nsp_loss( + rearrange(seq_relationship_score, "... t -> (...) t"), + rearrange(next_sentence_label, "... -> (...)"), + ) + total_loss = masked_lm_loss.float() + next_sentence_loss.float() + + return BertForPreTrainingOutput( + loss=total_loss, + prediction_logits=prediction_scores, + seq_relationship_logits=seq_relationship_score, + ) + + +def remap_state_dict(state_dict, config: PretrainedConfig): + """ + Map the state_dict of a Huggingface BERT model to be flash_attn compatible.
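+ 
+ Usage sketch (illustrative; the checkpoint name is an example, and loading with
+ strict=False mirrors how from_pretrained above uses this function):
+ state_dict = state_dict_from_pretrained("bert-base-uncased")
+ model = BertModel(config)
+ model.load_state_dict(remap_state_dict(state_dict, config), strict=False)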
+ """ + + # LayerNorm + def key_mapping_ln_gamma_beta(key): + key = re.sub(r"LayerNorm.gamma$", "LayerNorm.weight", key) + key = re.sub(r"LayerNorm.beta$", "LayerNorm.bias", key) + return key + + state_dict = OrderedDict((key_mapping_ln_gamma_beta(k), v) for k, v in state_dict.items()) + + # Layers + def key_mapping_layers(key): + return re.sub(r"^bert.encoder.layer.", "bert.encoder.layers.", key) + + state_dict = OrderedDict((key_mapping_layers(k), v) for k, v in state_dict.items()) + + # LayerNorm + def key_mapping_ln(key): + key = re.sub(r"^bert.embeddings.LayerNorm.", "bert.emb_ln.", key) + key = re.sub( + r"^bert.encoder.layers.(\d+).attention.output.LayerNorm.(weight|bias)", + r"bert.encoder.layers.\1.norm1.\2", + key, + ) + key = re.sub( + r"^bert.encoder.layers.(\d+).output.LayerNorm.(weight|bias)", + r"bert.encoder.layers.\1.norm2.\2", + key, + ) + key = re.sub( + r"^cls.predictions.transform.LayerNorm.(weight|bias)", + r"cls.predictions.transform.layer_norm.\1", + key, + ) + return key + + state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items()) + + # MLP + def key_mapping_mlp(key): + key = re.sub( + r"^bert.encoder.layers.(\d+).intermediate.dense.(weight|bias)", + r"bert.encoder.layers.\1.mlp.fc1.\2", + key, + ) + key = re.sub( + r"^bert.encoder.layers.(\d+).output.dense.(weight|bias)", + r"bert.encoder.layers.\1.mlp.fc2.\2", + key, + ) + return key + + state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items()) + + # Attention + last_layer_subset = getattr(config, "last_layer_subset", False) + for d in range(config.num_hidden_layers): + Wq = state_dict.pop(f"bert.encoder.layers.{d}.attention.self.query.weight") + Wk = state_dict.pop(f"bert.encoder.layers.{d}.attention.self.key.weight") + Wv = state_dict.pop(f"bert.encoder.layers.{d}.attention.self.value.weight") + bq = state_dict.pop(f"bert.encoder.layers.{d}.attention.self.query.bias") + bk = state_dict.pop(f"bert.encoder.layers.{d}.attention.self.key.bias") + bv = state_dict.pop(f"bert.encoder.layers.{d}.attention.self.value.bias") + if not (last_layer_subset and d == config.num_hidden_layers - 1): + state_dict[f"bert.encoder.layers.{d}.mixer.Wqkv.weight"] = torch.cat( + [Wq, Wk, Wv], dim=0 + ) + state_dict[f"bert.encoder.layers.{d}.mixer.Wqkv.bias"] = torch.cat([bq, bk, bv], dim=0) + else: + state_dict[f"bert.encoder.layers.{d}.mixer.Wq.weight"] = Wq + state_dict[f"bert.encoder.layers.{d}.mixer.Wkv.weight"] = torch.cat([Wk, Wv], dim=0) + state_dict[f"bert.encoder.layers.{d}.mixer.Wq.bias"] = bq + state_dict[f"bert.encoder.layers.{d}.mixer.Wkv.bias"] = torch.cat([bk, bv], dim=0) + + def key_mapping_attn(key): + return re.sub( + r"^bert.encoder.layers.(\d+).attention.output.dense.(weight|bias)", + r"bert.encoder.layers.\1.mixer.out_proj.\2", + key, + ) + + state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items()) + + def key_mapping_decoder_bias(key): + return re.sub(r"^cls.predictions.bias", "cls.predictions.decoder.bias", key) + + state_dict = OrderedDict((key_mapping_decoder_bias(k), v) for k, v in state_dict.items()) + + # Word embedding + pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1) + if pad_vocab_size_multiple > 1: + word_embeddings = state_dict["bert.embeddings.word_embeddings.weight"] + state_dict["bert.embeddings.word_embeddings.weight"] = F.pad( + word_embeddings, (0, 0, 0, config.vocab_size - word_embeddings.shape[0]) + ) + decoder_weight = state_dict["cls.predictions.decoder.weight"] + 
state_dict["cls.predictions.decoder.weight"] = F.pad( + decoder_weight, (0, 0, 0, config.vocab_size - decoder_weight.shape[0]) + ) + # If the vocab was padded, we want to set the decoder bias for those padded indices to be + # strongly negative (i.e. the decoder shouldn't predict those indices). + # TD [2022-05-09]: I don't think it affects the MLPerf training. + decoder_bias = state_dict["cls.predictions.decoder.bias"] + state_dict["cls.predictions.decoder.bias"] = F.pad( + decoder_bias, (0, config.vocab_size - decoder_bias.shape[0]), value=-100.0 + ) + + return state_dict + + +def inv_remap_state_dict(state_dict, config: PretrainedConfig): + """ + Map the state_dict of a flash_attn model to be Huggingface BERT compatible. + + This function is meant to be the inverse of remap_state_dict. + """ + # Word embedding + pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1) + if pad_vocab_size_multiple > 1: + word_embeddings = state_dict["bert.embeddings.word_embeddings.weight"] + decoder_weight = state_dict["cls.predictions.decoder.weight"] + decoder_bias = state_dict["cls.predictions.decoder.bias"] + # unpad embeddings + state_dict["bert.embeddings.word_embeddings.weight"] = word_embeddings[ + : config.orig_vocab_size, : + ] + state_dict["cls.predictions.decoder.weight"] = decoder_weight[: config.orig_vocab_size, :] + state_dict["cls.predictions.decoder.bias"] = decoder_bias[: config.orig_vocab_size] + + for d in range(config.num_hidden_layers): + last_layer_subset = getattr(config, "last_layer_subset", False) + if not last_layer_subset or d != (config.num_hidden_layers - 1): + Wqkv_weights = state_dict.pop(f"bert.encoder.layers.{d}.mixer.Wqkv.weight") + Wqkv_biases = state_dict.pop(f"bert.encoder.layers.{d}.mixer.Wqkv.bias") + state_dict[f"bert.encoder.layers.{d}.attention.self.query.weight"] = Wqkv_weights[ + : Wqkv_weights.shape[0] // 3, : + ] + state_dict[f"bert.encoder.layers.{d}.attention.self.key.weight"] = Wqkv_weights[ + Wqkv_weights.shape[0] // 3 : 2 * Wqkv_weights.shape[0] // 3, : + ] + state_dict[f"bert.encoder.layers.{d}.attention.self.value.weight"] = Wqkv_weights[ + 2 * Wqkv_weights.shape[0] // 3 :, : + ] + state_dict[f"bert.encoder.layers.{d}.attention.self.query.bias"] = Wqkv_biases[ + : Wqkv_biases.shape[0] // 3 + ] + state_dict[f"bert.encoder.layers.{d}.attention.self.key.bias"] = Wqkv_biases[ + Wqkv_biases.shape[0] // 3 : 2 * Wqkv_biases.shape[0] // 3 + ] + state_dict[f"bert.encoder.layers.{d}.attention.self.value.bias"] = Wqkv_biases[ + 2 * Wqkv_biases.shape[0] // 3 : + ] + else: + Wq_weight = state_dict.pop(f"bert.encoder.layers.{d}.mixer.Wq.weight") + Wkv_weights = state_dict.pop(f"bert.encoder.layers.{d}.mixer.Wkv.weight") + Wq_bias = state_dict.pop(f"bert.encoder.layers.{d}.mixer.Wq.bias") + Wkv_biases = state_dict.pop(f"bert.encoder.layers.{d}.mixer.Wkv.bias") + state_dict[f"bert.encoder.layers.{d}.attention.self.query.weight"] = Wq_weight + state_dict[f"bert.encoder.layers.{d}.attention.self.key.weight"] = Wkv_weights[ + : Wkv_weights.shape[0] // 2, : + ] + state_dict[f"bert.encoder.layers.{d}.attention.self.value.weight"] = Wkv_weights[ + Wkv_weights.shape[0] // 2 :, : + ] + state_dict[f"bert.encoder.layers.{d}.attention.self.query.bias"] = Wq_bias + state_dict[f"bert.encoder.layers.{d}.attention.self.key.bias"] = Wkv_biases[ + : Wkv_biases.shape[0] // 2 + ] + state_dict[f"bert.encoder.layers.{d}.attention.self.value.bias"] = Wkv_biases[ + Wkv_biases.shape[0] // 2 : + ] + + def inv_key_mapping_ln(key): + key = re.sub(r"bert.emb_ln.", 
"bert.embeddings.LayerNorm.", key) + key = re.sub( + r"bert.encoder.layers.(\d+).norm1.(weight|bias)", + r"bert.encoder.layers.\1.attention.output.LayerNorm.\2", + key, + ) + key = re.sub( + r"bert.encoder.layers.(\d+).norm2.(weight|bias)", + r"bert.encoder.layers.\1.output.LayerNorm.\2", + key, + ) + key = re.sub( + r"cls.predictions.transform.layer_norm.(weight|bias)", + r"cls.predictions.transform.LayerNorm.\1", + key, + ) + return key + + def inv_key_mapping_ln_gamma_beta(key): + key = re.sub(r"LayerNorm.weight$", "LayerNorm.gamma", key) + key = re.sub(r"LayerNorm.bias$", "LayerNorm.beta", key) + return key + + def inv_key_mapping_layers(key): + return re.sub(r"bert.encoder.layers.", "bert.encoder.layer.", key) + + def inv_key_mapping_mlp(key): + key = re.sub( + r"bert.encoder.layer.(\d+).mlp.fc1.(weight|bias)", + r"bert.encoder.layer.\1.intermediate.dense.\2", + key, + ) + key = re.sub( + r"bert.encoder.layer.(\d+).mlp.fc2.(weight|bias)", + r"bert.encoder.layer.\1.output.dense.\2", + key, + ) + return key + + def inv_key_mapping_attn(key): + return re.sub( + r"bert.encoder.layer.(\d+).mixer.out_proj.(weight|bias)", + r"bert.encoder.layer.\1.attention.output.dense.\2", + key, + ) + + def inv_key_mapping_decoder_bias(key): + return re.sub(r"cls.predictions.decoder.bias", "cls.predictions.bias", key) + + state_dict = OrderedDict((inv_key_mapping_ln(key), value) for key, value in state_dict.items()) + state_dict = OrderedDict( + (inv_key_mapping_ln_gamma_beta(key), value) for key, value in state_dict.items() + ) + state_dict = OrderedDict( + (inv_key_mapping_layers(key), value) for key, value in state_dict.items() + ) + state_dict = OrderedDict((inv_key_mapping_mlp(key), value) for key, value in state_dict.items()) + state_dict = OrderedDict( + (inv_key_mapping_attn(key), value) for key, value in state_dict.items() + ) + state_dict = OrderedDict( + (inv_key_mapping_decoder_bias(key), value) for key, value in state_dict.items() + ) + + return state_dict diff --git a/.venv/lib/python3.11/site-packages/xformers/_flash_attn/models/bigcode.py b/.venv/lib/python3.11/site-packages/xformers/_flash_attn/models/bigcode.py new file mode 100644 index 0000000000000000000000000000000000000000..234944d4d6907fb3e1b0c2c3c315a2bee29d7775 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/xformers/_flash_attn/models/bigcode.py @@ -0,0 +1,233 @@ +import math +import re +from collections import OrderedDict + +import torch +import torch.nn.functional as F +from transformers import GPT2Config, GPTBigCodeConfig, PretrainedConfig + + +def remap_state_dict_hf_bigcode(state_dict, config: PretrainedConfig): + """ + Map the state_dict of a Huggingface BigCode model to be flash_attn compatible. + """ + + # Word embedding and position embedding + def key_mapping_pos_emb(key): + return re.sub(r"^transformer.wpe.", "transformer.embeddings.position_embeddings.", key) + + state_dict = OrderedDict((key_mapping_pos_emb(k), v) for k, v in state_dict.items()) + word_embeddings = state_dict.pop("transformer.wte.weight") + # It's possible that vocab_size is padded to be a multiple of 8, for example. 
+ pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1) + vocab_size = math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple + state_dict["transformer.embeddings.word_embeddings.weight"] = F.pad( + word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0]) + ) + state_dict["lm_head.weight"] = state_dict["transformer.embeddings.word_embeddings.weight"] + + # LayerNorm + def key_mapping_ln(key): + key = re.sub(r"^transformer.ln_f.(weight|bias)", r"transformer.ln_f.\1", key) + key = re.sub( + r"^transformer.h.(\d+).ln_(1|2).(weight|bias)", + r"transformer.layers.\1.norm\2.\3", + key, + ) + return key + + state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items()) + + def key_mapping_mlp(key): + key = re.sub( + r"^transformer.h.(\d+).mlp.c_fc.weight", + r"transformer.layers.\1.mlp.fc1.weight", + key, + ) + key = re.sub( + r"^transformer.h.(\d+).mlp.c_proj.weight", + r"transformer.layers.\1.mlp.fc2.weight", + key, + ) + key = re.sub( + r"^transformer.h.(\d+).mlp.c_fc.bias", + r"transformer.layers.\1.mlp.fc1.bias", + key, + ) + key = re.sub( + r"^transformer.h.(\d+).mlp.c_proj.bias", + r"transformer.layers.\1.mlp.fc2.bias", + key, + ) + return key + + state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items()) + + # TODO: add support for multi-head attention + assert config.multi_query, "Only multi-query attention is supported" + + # Attention + for d in range(config.num_hidden_layers): + embed_dim = config.n_embd + head_dim = embed_dim // config.n_head + + c_attn_weight = state_dict.pop(f"transformer.h.{d}.attn.c_attn.weight") + # with multi-query attention, the weight has shape (embed_dim + 2 * head_dim, embed_dim) + # see https://github.com/huggingface/transformers/blob/95b374952dc27d8511541d6f5a4e22c9ec11fb24/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py#L112 + # see also https://github.com/ggerganov/ggml/blob/dd1d575956e54c5bdc07632f25506b3b1884dbd2/examples/starcoder/convert-hf-to-ggml.py#L183 + # ((n_head + 2) * head_dim, embed_dim) -> (3 * n_head * head_dim, embed_dim) + q, k, v = torch.split(c_attn_weight, [embed_dim, head_dim, head_dim], dim=0) + # duplicate k, v along the first axis: (head_dim, embed_dim) -> (n_head * head_dim, embed_dim) + k = torch.tile(k, (config.n_head, 1)) + v = torch.tile(v, (config.n_head, 1)) + state_dict[f"transformer.layers.{d}.mixer.Wqkv.weight"] = torch.cat((q, k, v), dim=0) + + # same deal with the bias, which is 1-D + c_attn_bias = state_dict.pop(f"transformer.h.{d}.attn.c_attn.bias") + # ((n_head + 2) * head_dim,) -> (3 * n_head * head_dim,) + q, k, v = torch.split(c_attn_bias, [embed_dim, head_dim, head_dim], dim=0) + # duplicate k, v along the first axis: (head_dim,) -> (n_head * head_dim,) + k = torch.tile(k, (config.n_head,)) + v = torch.tile(v, (config.n_head,)) + state_dict[f"transformer.layers.{d}.mixer.Wqkv.bias"] = torch.cat((q, k, v), dim=0) + + def key_mapping_attn(key): + key = re.sub( + r"^transformer.h.(\d+).attn.c_proj.weight", + r"transformer.layers.\1.mixer.out_proj.weight", + key, + ) + key = re.sub( + r"^transformer.h.(\d+).attn.c_proj.bias", + r"transformer.layers.\1.mixer.out_proj.bias", + key, + ) + return key + + state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items()) + + return state_dict + + +def inv_remap_state_dict_hf_bigcode(state_dict, config: PretrainedConfig): + """ + Map the state_dict of a flash_attn model to be Huggingface BigCode compatible.
+ + This function is meant to be the inverse of remap_state_dict_hf_bigcode. + """ + + # Word embedding and position embeddings + def inv_key_mapping_pos_emb(key): + return re.sub(r"^transformer.embeddings.position_embeddings.", "transformer.wpe.", key) + + state_dict = OrderedDict((inv_key_mapping_pos_emb(k), v) for k, v in state_dict.items()) + word_embeddings = state_dict.pop("transformer.embeddings.word_embeddings.weight") + + # unpad the vocab dimension (remap_state_dict_hf_bigcode padded along dim 0) + word_embeddings = word_embeddings[: config.vocab_size, :] + state_dict["transformer.wte.weight"] = word_embeddings + state_dict["lm_head.weight"] = word_embeddings + + # LayerNorm + def inv_key_mapping_ln(key): + key = re.sub(r"^transformer.ln_f.(weight|bias)", r"transformer.ln_f.\1", key) + key = re.sub( + r"^transformer.layers.(\d+).norm(1|2).(weight|bias)", + r"transformer.h.\1.ln_\2.\3", + key, + ) + return key + + state_dict = OrderedDict((inv_key_mapping_ln(k), v) for k, v in state_dict.items()) + + # MLPs + def inv_key_mapping_mlp(key): + key = re.sub( + r"^transformer.layers.(\d+).mlp.fc1.weight", + r"transformer.h.\1.mlp.c_fc.weight", + key, + ) + key = re.sub( + r"^transformer.layers.(\d+).mlp.fc2.weight", + r"transformer.h.\1.mlp.c_proj.weight", + key, + ) + key = re.sub( + r"^transformer.layers.(\d+).mlp.fc1.bias", + r"transformer.h.\1.mlp.c_fc.bias", + key, + ) + key = re.sub( + r"^transformer.layers.(\d+).mlp.fc2.bias", + r"transformer.h.\1.mlp.c_proj.bias", + key, + ) + return key + + state_dict = OrderedDict((inv_key_mapping_mlp(k), v) for k, v in state_dict.items()) + + # Attention + for d in range(config.num_hidden_layers): + embed_dim = config.n_embd + head_dim = embed_dim // config.n_head + + Wqkv_weight = state_dict.pop(f"transformer.layers.{d}.mixer.Wqkv.weight") + q, k, v = torch.split( + Wqkv_weight, [embed_dim, head_dim * config.n_head, head_dim * config.n_head], dim=0 + ) + c_attn_weight = torch.cat((q, k[:head_dim], v[:head_dim]), dim=0) + state_dict[f"transformer.h.{d}.attn.c_attn.weight"] = c_attn_weight + + # Same deal with the bias + Wqkv_bias = state_dict.pop(f"transformer.layers.{d}.mixer.Wqkv.bias") + q, k, v = torch.split( + Wqkv_bias, [embed_dim, head_dim * config.n_head, head_dim * config.n_head], dim=0 + ) + c_attn_bias = torch.cat((q, k[:head_dim], v[:head_dim]), dim=0) + state_dict[f"transformer.h.{d}.attn.c_attn.bias"] = c_attn_bias + + def inv_key_mapping_attn(key): + key = re.sub( + r"^transformer.layers.(\d+).mixer.out_proj.weight", + r"transformer.h.\1.attn.c_proj.weight", + key, + ) + key = re.sub( + r"^transformer.layers.(\d+).mixer.out_proj.bias", + r"transformer.h.\1.attn.c_proj.bias", + key, + ) + return key + + state_dict = OrderedDict((inv_key_mapping_attn(k), v) for k, v in state_dict.items()) + + return state_dict + + +def bigcode_config_to_gpt2_config(bigcode_config: GPTBigCodeConfig) -> GPT2Config: + return GPT2Config( + activation_function=bigcode_config.activation_function, + attn_pdrop=bigcode_config.attn_pdrop, + bos_token_id=bigcode_config.bos_token_id, + embd_pdrop=bigcode_config.embd_pdrop, + eos_token_id=bigcode_config.eos_token_id, + initializer_range=bigcode_config.initializer_range, + layer_norm_epsilon=bigcode_config.layer_norm_epsilon, + max_batch_size=bigcode_config.max_batch_size, + max_sequence_length=bigcode_config.max_sequence_length, + model_type=bigcode_config.model_type, + multi_query=bigcode_config.multi_query, + n_embd=bigcode_config.n_embd, + n_head=bigcode_config.n_head, + n_inner=bigcode_config.n_inner, + n_layer=bigcode_config.n_layer, + n_positions=bigcode_config.n_positions, +
resid_pdrop=bigcode_config.resid_pdrop, + scale_attn_weights=bigcode_config.scale_attn_weights, + summary_activation=bigcode_config.summary_activation, + summary_first_dropout=bigcode_config.summary_first_dropout, + summary_proj_to_labels=bigcode_config.summary_proj_to_labels, + summary_type=bigcode_config.summary_type, + summary_use_proj=bigcode_config.summary_use_proj, + use_cache=bigcode_config.use_cache, + vocab_size=bigcode_config.vocab_size, + ) diff --git a/.venv/lib/python3.11/site-packages/xformers/_flash_attn/models/btlm.py b/.venv/lib/python3.11/site-packages/xformers/_flash_attn/models/btlm.py new file mode 100644 index 0000000000000000000000000000000000000000..295e12062320be819dd835de4a866607650431b2 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/xformers/_flash_attn/models/btlm.py @@ -0,0 +1,102 @@ +# Copyright (c) 2023, Tri Dao. + +import math +import json +import re +from pathlib import Path + +from collections import OrderedDict + +import torch +import torch.nn.functional as F + +from einops import rearrange +from transformers import GPT2Config, AutoConfig, PretrainedConfig + + +def remap_state_dict_hf_btlm(state_dict, config): + # Word embedding and position embedding + def key_mapping_pos_emb(key): + return re.sub(r"^transformer.wpe.", "transformer.embeddings.position_embeddings.", key) + + if "transformer.wpe.weight" in state_dict: + state_dict = OrderedDict((key_mapping_pos_emb(k), v) for k, v in state_dict.items()) + word_embeddings = state_dict.pop("transformer.wte.weight") + # It's possible that vocab_size is padded to be a multiple of 8, for example. + pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1) + vocab_size = math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple + state_dict["transformer.embeddings.word_embeddings.weight"] = F.pad( + word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0]) + ) + state_dict["lm_head.weight"] = state_dict["transformer.embeddings.word_embeddings.weight"] + + # LayerNorm + def key_mapping_ln(key): + key = re.sub(r"^transformer.ln_f.(weight|bias)", r"transformer.ln_f.\1", key) + key = re.sub(r"^transformer.h.(\d+).ln_(1|2).(weight|bias)", r"transformer.layers.\1.norm\2.\3", key) + return key + + state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items()) + + # MLP + for d in range(config.num_hidden_layers): + W1 = state_dict.pop(f"transformer.h.{d}.mlp.c_fc.weight") + W3 = state_dict.pop(f"transformer.h.{d}.mlp.c_fc2.weight") + state_dict[f"transformer.layers.{d}.mlp.fc1.weight"] = torch.cat([W1.t(), W3.t()], dim=0) + b1 = state_dict.pop(f"transformer.h.{d}.mlp.c_fc.bias") + b3 = state_dict.pop(f"transformer.h.{d}.mlp.c_fc2.bias") + state_dict[f"transformer.layers.{d}.mlp.fc1.bias"] = torch.cat([b1, b3], dim=0) + W2 = state_dict.pop(f"transformer.h.{d}.mlp.c_proj.weight") + state_dict[f"transformer.layers.{d}.mlp.fc2.weight"] = W2.t() + + def key_mapping_mlp(key): + key = re.sub(r"^transformer.h.(\d+).mlp.c_proj.bias", r"transformer.layers.\1.mlp.fc2.bias", key) + return key + + state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items()) + + # Attention + for d in range(config.num_hidden_layers): + Wqkv = state_dict.pop(f"transformer.h.{d}.attn.c_attn.weight") + state_dict[f"transformer.layers.{d}.mixer.Wqkv.weight"] = Wqkv.t() + Wout = state_dict.pop(f"transformer.h.{d}.attn.c_proj.weight") + state_dict[f"transformer.layers.{d}.mixer.out_proj.weight"] = Wout.t() + state_dict.pop(f"transformer.relative_pe.slopes") # We don't 
store the Alibi slopes + + def key_mapping_attn(key): + key = re.sub(r"^transformer.h.(\d+).attn.c_attn.bias", r"transformer.layers.\1.mixer.Wqkv.bias", key) + key = re.sub( + r"^transformer.h.(\d+).attn.c_proj.bias", r"transformer.layers.\1.mixer.out_proj.bias", key + ) + return key + + state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items()) + + return state_dict + + +def btlm_config_to_gpt2_config(btlm_config: PretrainedConfig) -> GPT2Config: + return GPT2Config( + vocab_size=btlm_config.vocab_size, + n_positions=0 if btlm_config.position_embedding_type == "alibi" else btlm_config.n_positions, + n_embd=btlm_config.hidden_size, + n_layer=btlm_config.num_hidden_layers, + n_head=btlm_config.num_attention_heads, + n_inner=btlm_config.n_inner, + activation_function=btlm_config.activation_function, + resid_pdrop=btlm_config.resid_pdrop, + embd_pdrop=btlm_config.embd_pdrop, + attn_pdrop=btlm_config.attn_pdrop, + layer_norm_epsilon=btlm_config.layer_norm_epsilon, + initializer_range=btlm_config.initializer_range, + bos_token_id=btlm_config.bos_token_id, + eos_token_id=btlm_config.eos_token_id, + # These are new arguments not in the original GPT2Config + use_alibi=btlm_config.position_embedding_type == "alibi", + use_flash_attn=btlm_config.position_embedding_type == "alibi", # Alibi code path requires flash_attn + mup_width_scale=btlm_config.mup_width_scale, + mup_embeddings_multiplier=btlm_config.mup_embeddings_scale, + mup_output_multiplier=btlm_config.mup_output_alpha, + mup_scale_qk_dot_by_d=btlm_config.mup_scale_qk_dot_by_d, + mlp_multiple_of=1, + ) diff --git a/.venv/lib/python3.11/site-packages/xformers/_flash_attn/models/falcon.py b/.venv/lib/python3.11/site-packages/xformers/_flash_attn/models/falcon.py new file mode 100644 index 0000000000000000000000000000000000000000..4b02ec7727740eaa9ca70a7f0ca64df94fff4c3a --- /dev/null +++ b/.venv/lib/python3.11/site-packages/xformers/_flash_attn/models/falcon.py @@ -0,0 +1,143 @@ +# Copyright (c) 2023, Tri Dao. + +import math +import re +from collections import OrderedDict + +import torch +import torch.nn.functional as F +from einops import rearrange +from transformers import FalconConfig, GPT2Config + + +def remap_state_dict_hf_falcon(state_dict, config): + def key_mapping_layers(key): + return re.sub(r"^transformer.h.", "transformer.layers.", key) + + state_dict = OrderedDict((key_mapping_layers(k), v) for k, v in state_dict.items()) + # Word embedding + def key_mapping_emb(key): + return re.sub( + r"^transformer.word_embeddings.", "transformer.embeddings.word_embeddings.", key + ) + + state_dict = OrderedDict((key_mapping_emb(k), v) for k, v in state_dict.items()) + word_embeddings = state_dict.pop("transformer.embeddings.word_embeddings.weight") + # It's possible that vocab_size is padded to be a multiple of 8, for example. + pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1) + vocab_size = math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple + state_dict["transformer.embeddings.word_embeddings.weight"] = F.pad( + word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0]) + ) + if getattr(config, "tie_word_embeddings"): + state_dict["lm_head.weight"] = state_dict["transformer.embeddings.word_embeddings.weight"] + else: + output_embeddings = state_dict.pop("lm_head.weight") + # It's possible that vocab_size is padded to be a multiple of 8, for example. 
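+ # Pad the untied lm_head weight (and its bias below) the same way as the word
+ # embeddings above, so their leading dimension matches the rounded-up vocab_size.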
+ state_dict["lm_head.weight"] = F.pad( + output_embeddings, (0, 0, 0, vocab_size - output_embeddings.shape[0]) + ) + output_embeddings_bias = state_dict.pop("lm_head.bias") + state_dict["lm_head.bias"] = F.pad( + output_embeddings_bias, (0, vocab_size - output_embeddings_bias.shape[0]) + ) + + # LayerNorm + def key_mapping_ln(key): + key = re.sub( + r"^transformer.layers.(\d+).input_layernorm.", r"transformer.layers.\1.norm1.", key + ) + key = re.sub( + r"^transformer.layers.(\d+).post_attention_layernorm.", + r"transformer.layers.\1.norm2.", + key, + ) + key = re.sub(r"^transformer.layers.(\d+).ln_attn.", r"transformer.layers.\1.norm1.", key) + key = re.sub(r"^transformer.layers.(\d+).ln_mlp.", r"transformer.layers.\1.norm2.", key) + return key + + state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items()) + + # MLP + def key_mapping_mlp(key): + key = re.sub( + r"^transformer.layers.(\d+).mlp.dense_h_to_4h.", r"transformer.layers.\1.mlp.fc1.", key + ) + key = re.sub( + r"^transformer.layers.(\d+).mlp.dense_4h_to_h.", r"transformer.layers.\1.mlp.fc2.", key + ) + return key + + state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items()) + + def key_mapping_attn(key): + key = re.sub( + r"^transformer.layers.(\d+).self_attention.query_key_value.", + r"transformer.layers.\1.mixer.Wqkv.", + key, + ) + key = re.sub( + r"^transformer.layers.(\d+).self_attention.dense.", + r"transformer.layers.\1.mixer.out_proj.", + key, + ) + return key + + state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items()) + n_head = config.n_head + n_head_kv = getattr(config, "n_head_kv", 1) + headdim = config.hidden_size // n_head + for l in range(config.n_layer): + # The weights are stored in a different layout compared to our implementation + Wqkv = rearrange( + state_dict.pop(f"transformer.layers.{l}.mixer.Wqkv.weight"), + "(group ratio headdim) ... -> group ratio headdim ...", + ratio=n_head // n_head_kv + 2, + headdim=headdim, + ) + Wq = rearrange(Wqkv[:, :-2], "group ratio headdim ... -> (group ratio headdim) ...") + Wk = rearrange(Wqkv[:, [-2]], "group ratio headdim ... -> (group ratio headdim) ...") + Wv = rearrange(Wqkv[:, [-1]], "group ratio headdim ... -> (group ratio headdim) ...") + state_dict[f"transformer.layers.{l}.mixer.Wqkv.weight"] = torch.cat([Wq, Wk, Wv], dim=0) + + return state_dict + + +def falcon_config_to_gpt2_config(falcon_config: FalconConfig) -> GPT2Config: + # The 40b config uses "n_head_kv" instead of "num_kv_heads" + n_head_kv = getattr( + falcon_config, + "n_head_kv", + 1 if getattr(falcon_config, "multi_query", False) else falcon_config.n_head, + ) + # HACK: the 40b config has 2 LN per layer instead of 1, but that's not reflected in the config. 
+ # So we have to infer it from the number of heads in the key/value block + parallel_block_tied_norm = n_head_kv == 1 + return GPT2Config( + vocab_size=falcon_config.vocab_size, + n_positions=0, # No absolute position embedding + n_embd=falcon_config.hidden_size, + n_layer=falcon_config.n_layer, + n_head=falcon_config.n_head, + n_inner=falcon_config.hidden_size * 4, + activation_function="gelu", + resid_pdrop=falcon_config.hidden_dropout, + embd_pdrop=0.0, # There doesn't seem to be any embedding dropout + attn_pdrop=falcon_config.attention_dropout, + layer_norm_epsilon=falcon_config.layer_norm_epsilon, + initializer_range=falcon_config.initializer_range, + bos_token_id=falcon_config.bos_token_id, + eos_token_id=falcon_config.eos_token_id, + # These are new arguments not in the original GPT2Config + parallel_block=falcon_config.parallel_attn, + n_head_kv=n_head_kv, + parallel_block_tied_norm=parallel_block_tied_norm, + rotary_emb_fraction=1.0, + rotary_emb_interleaved=False, + tie_word_embeddings=True, + qkv_proj_bias=falcon_config.bias, + out_proj_bias=falcon_config.bias, + mlp_fc1_bias=falcon_config.bias, + mlp_fc2_bias=falcon_config.bias, + lm_head_bias=False, + ) diff --git a/.venv/lib/python3.11/site-packages/xformers/_flash_attn/models/gpt.py b/.venv/lib/python3.11/site-packages/xformers/_flash_attn/models/gpt.py new file mode 100644 index 0000000000000000000000000000000000000000..3539f8f901695b29454358972d65031f4c4fabeb --- /dev/null +++ b/.venv/lib/python3.11/site-packages/xformers/_flash_attn/models/gpt.py @@ -0,0 +1,1080 @@ +# Copyright (c) 2024, Tri Dao. + +import logging +import math +import re +from collections import OrderedDict, namedtuple +from collections.abc import Sequence +from functools import partial +from typing import Dict, List + +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange +from transformers import GPT2Config + +from flash_attn.models.bigcode import remap_state_dict_hf_bigcode +from flash_attn.models.falcon import remap_state_dict_hf_falcon +from flash_attn.models.gpt_neox import remap_state_dict_hf_gpt_neox +from flash_attn.models.gptj import remap_state_dict_hf_gptj +from flash_attn.models.llama import remap_state_dict_hf_llama +from flash_attn.models.opt import remap_state_dict_hf_opt +from flash_attn.modules.block import Block, ParallelBlock +from flash_attn.modules.embedding import GPT2Embeddings, ParallelGPT2Embeddings +from flash_attn.modules.mha import MHA, ParallelMHA +from flash_attn.modules.mlp import ( + FusedMLP, + GatedMlp, + Mlp, + ParallelFusedMLP, + ParallelGatedMlp, + ParallelMLP, +) +from flash_attn.ops.activations import sqrelu_fwd +from flash_attn.utils.distributed import ( + all_gather, + all_gather_raw, + get_dim_for_local_rank, + sync_shared_params, +) +from flash_attn.utils.generation import GenerationMixin +from flash_attn.utils.pretrained import state_dict_from_pretrained + +try: + from flash_attn.ops.fused_dense import ColumnParallelLinear +except ImportError: + ColumnParallelLinear = None + +try: + from flash_attn.ops.triton.mlp import FusedDenseSqreluDense +except ImportError: + FusedDenseSqreluDense = None + +try: + from flash_attn.ops.triton.layer_norm import layer_norm_fn, RMSNorm +except ImportError: + layer_norm_fn, RMSNorm = None, None + +logger = logging.getLogger(__name__) + + +def create_mixer_cls(config, layer_idx=None, process_group=None, device=None, dtype=None): + factory_kwargs = {"device": device, "dtype": dtype} + head_dim = getattr(config, "head_dim", 
config.hidden_size // config.num_attention_heads) + attn_scale_power = 0.5 if not getattr(config, "mup_scale_qk_dot_by_d", False) else 1.0 + softmax_scale = 1.0 if not config.scale_attn_weights else (head_dim ** (-attn_scale_power)) + softmax_scale *= getattr(config, "mup_attn_multiplier", 1.0) + if config.scale_attn_by_inverse_layer_idx: + assert layer_idx is not None + softmax_scale /= float(layer_idx + 1) + dwconv = getattr(config, "attn_dwconv", False) + if dwconv: + assert process_group is None, "TensorParallel MHA does not support dwconv yet" + qkv_proj_bias = getattr(config, "qkv_proj_bias", True) + out_proj_bias = getattr(config, "out_proj_bias", True) + rotary_emb_dim = int(getattr(config, "rotary_emb_fraction", 0.0) * head_dim) + rotary_emb_base = getattr(config, "rotary_emb_base", 10000.0) + rotary_emb_scale_base = getattr(config, "rotary_emb_scale_base", None) + rotary_emb_interleaved = getattr(config, "rotary_emb_interleaved", False) + use_alibi = getattr(config, "use_alibi", False) + window_size = getattr(config, "window_size", (-1, -1)) + use_flash_attn = getattr(config, "use_flash_attn", False) + fused_bias_fc = getattr(config, "fused_bias_fc", False) + if not fused_bias_fc: + assert process_group is None, "TensorParallel MHA requires fused_bias_fc" + mha_cls = MHA if process_group is None else ParallelMHA + serial_kwargs = ( + {"fused_bias_fc": fused_bias_fc, "dwconv": dwconv} if process_group is None else {} + ) + parallel_kwargs = ( + { + "process_group": process_group, + "sequence_parallel": getattr(config, "sequence_parallel", True), + } + if process_group is not None + else {} + ) + num_heads_kv = getattr(config, "n_head_kv", None) + mixer_cls = partial( + mha_cls, + num_heads=config.num_attention_heads, + num_heads_kv=num_heads_kv, + qkv_proj_bias=qkv_proj_bias, + out_proj_bias=out_proj_bias, + dropout=config.attn_pdrop, + softmax_scale=softmax_scale, + causal=True, + layer_idx=layer_idx, + rotary_emb_dim=rotary_emb_dim, + rotary_emb_base=rotary_emb_base, + rotary_emb_scale_base=rotary_emb_scale_base, + rotary_emb_interleaved=rotary_emb_interleaved, + use_alibi=use_alibi, + window_size=window_size, + use_flash_attn=use_flash_attn, + **serial_kwargs, + **parallel_kwargs, + **factory_kwargs, + ) + return mixer_cls + + +def create_mlp_cls(config, layer_idx=None, process_group=None, device=None, dtype=None): + factory_kwargs = {"device": device, "dtype": dtype} + mlp_fc1_bias = getattr(config, "mlp_fc1_bias", True) + mlp_fc2_bias = getattr(config, "mlp_fc2_bias", True) + fused_mlp = getattr(config, "fused_mlp", False) + if fused_mlp: + assert config.activation_function in [ + "gelu_new", + "gelu_fast", + "gelu_approx", + "gelu_pytorch_tanh", + "relu", + "sqrelu", + ] + fused_dense_sqrelu_dense = getattr(config, "fused_dense_sqrelu_dense", False) + if fused_dense_sqrelu_dense: + assert config.activation_function == "sqrelu", ( + "fused_dense_sqrelu_dense only " "supports approximate activation_function sqrelu" + ) + assert not (fused_dense_sqrelu_dense and fused_mlp) + if not fused_mlp and not fused_dense_sqrelu_dense: + assert config.activation_function in [ + "gelu", + "gelu_new", + "gelu_fast", + "gelu_approx", + "gelu_pytorch_tanh", + "relu", + "sqrelu", + "glu", + "swiglu", + "geglu", + ] + if config.activation_function in ["glu", "swiglu", "geglu"]: + activation = ( + F.sigmoid + if config.activation_function == "glu" + else (F.silu if config.activation_function == "swiglu" else F.gelu) + ) + mlp_cls = GatedMlp if process_group is None else ParallelGatedMlp + 
parallel_kwargs = ( + { + "process_group": process_group, + "sequence_parallel": getattr(config, "sequence_parallel", True), + } + if process_group is not None + else {} + ) + mlp_multiple_of = getattr(config, "mlp_multiple_of", 128) + mlp_cls = partial( + mlp_cls, + hidden_features=config.n_inner, + activation=activation, + bias1=mlp_fc1_bias, + bias2=mlp_fc2_bias, + multiple_of=mlp_multiple_of, + **parallel_kwargs, + **factory_kwargs, + ) + else: + if config.activation_function == "relu": + activation = partial(F.relu, inplace=True) + elif config.activation_function == "sqrelu": + activation = sqrelu_fwd + else: + approximate = ( + "tanh" + if config.activation_function + in ["gelu_new", "gelu_fast", "gelu_approx", "gelu_pytorch_tanh"] + else "none" + ) + activation = partial(F.gelu, approximate=approximate) + mlp_cls = Mlp if process_group is None else ParallelMLP + parallel_kwargs = ( + { + "process_group": process_group, + "sequence_parallel": getattr(config, "sequence_parallel", True), + } + if process_group is not None + else {} + ) + mlp_cls = partial( + mlp_cls, + hidden_features=config.n_inner, + activation=activation, + bias1=mlp_fc1_bias, + bias2=mlp_fc2_bias, + **parallel_kwargs, + **factory_kwargs, + ) + else: + mlp_checkpoint_lvl = getattr(config, "mlp_checkpoint_lvl", 0) + # mlp_checkpoint_lvl could be a list, which contains the checkpoint_lvl for each layer + if isinstance(mlp_checkpoint_lvl, Sequence): + assert layer_idx is not None + mlp_checkpoint_lvl = mlp_checkpoint_lvl[layer_idx] + if fused_mlp: + if FusedMLP is None: + raise ImportError("fused_dense is not installed") + activation = ( + "gelu_approx" + if config.activation_function + in ["gelu_new", "gelu_fast", "gelu_approx", "gelu_pytorch_tanh"] + else config.activation_function + ) + mlp_cls = FusedMLP if process_group is None else ParallelFusedMLP + parallel_kwargs = ( + { + "process_group": process_group, + "sequence_parallel": getattr(config, "sequence_parallel", True), + } + if process_group is not None + else {} + ) + mlp_cls = partial( + mlp_cls, + hidden_features=config.n_inner, + activation=activation, + checkpoint_lvl=mlp_checkpoint_lvl, + bias1=mlp_fc1_bias, + bias2=mlp_fc2_bias, + **parallel_kwargs, + **factory_kwargs, + ) + elif fused_dense_sqrelu_dense: + if process_group is not None: + assert fused_mlp, "Tensor Parallel is not implemented for FusedDenseSqreluDense" + assert FusedDenseSqreluDense is not None + mlp_cls = partial( + FusedDenseSqreluDense, + hidden_features=config.n_inner, + checkpoint_lvl=mlp_checkpoint_lvl, + **factory_kwargs, + ) + else: + raise RuntimeError("MLP type not supported") + return mlp_cls + + +def create_block(config, layer_idx=None, process_group=None, device=None, dtype=None): + factory_kwargs = {"device": device, "dtype": dtype} + sequence_parallel = getattr(config, "sequence_parallel", True) + mixer_cls = create_mixer_cls(config, layer_idx, process_group=process_group, **factory_kwargs) + mlp_cls = create_mlp_cls(config, layer_idx, process_group=process_group, **factory_kwargs) + use_rms_norm = getattr(config, "rms_norm", False) + norm_cls = partial( + nn.LayerNorm if not use_rms_norm else RMSNorm, + eps=config.layer_norm_epsilon, + **factory_kwargs, + ) + # TD [2022-07-30]: Force residual in fp32, seems to make fp16 training more stable + residual_in_fp32 = getattr(config, "residual_in_fp32", False) + resid_dropout1 = config.resid_pdrop if layer_idx is None or layer_idx > 0 else config.embd_pdrop + prenorm = getattr(config, "prenorm", True) + parallel_block = 
getattr(config, "parallel_block", False) + if not parallel_block: + block = Block( + config.hidden_size, + mixer_cls, + mlp_cls, + norm_cls=norm_cls, + prenorm=prenorm, + resid_dropout1=resid_dropout1, + resid_dropout2=config.resid_pdrop, + fused_dropout_add_ln=getattr(config, "fused_dropout_add_ln", False), + residual_in_fp32=residual_in_fp32, + sequence_parallel=sequence_parallel and process_group is not None, + mark_shared_params=process_group is not None, + ) + else: + assert prenorm + block = ParallelBlock( + config.hidden_size, + mixer_cls, + mlp_cls, + norm_cls=norm_cls, + resid_dropout1=resid_dropout1, + resid_dropout2=config.resid_pdrop, + tied_norm=getattr(config, "parallel_block_tied_norm", False), + fused_dropout_add_ln=getattr(config, "fused_dropout_add_ln", False), + residual_in_fp32=residual_in_fp32, + sequence_parallel=sequence_parallel and process_group is not None, + mark_shared_params=process_group is not None, + ) + block.layer_idx = layer_idx + return block + + +class GPTPreTrainedModel(nn.Module): + """An abstract class to handle weights initialization and + a simple interface for dowloading and loading pretrained models. + """ + + def __init__(self, config, *inputs, **kwargs): + super().__init__() + if not isinstance(config, GPT2Config): + raise ValueError( + "Parameter config in `{}(config)` should be an instance of class `GPT2Config`. " + "To create a model from a Google pretrained model use " + "`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format( + self.__class__.__name__, self.__class__.__name__ + ) + ) + self.config = config + + @classmethod + def from_pretrained( + cls, + model_name, + config, + *args, + strict=True, + device=None, + dtype=None, + world_size=1, + rank=0, + **kwargs, + ): + """ + Instantiate a GPTPreTrainedModel from a pre-trained model file or a pytorch state dict. + Download and cache the pre-trained model file if needed. + """ + # Instantiate model. 
+ model = cls(config, *args, device=device, dtype=dtype, **kwargs) + # Load the state_dict on CPU because we already initialized the model on GPU, and we don't + # want extra copies taking up more GPU memory + state_dict = state_dict_from_pretrained(model_name, device="cpu", dtype=dtype) + if model_name.startswith("gpt2"): + state_dict = remap_state_dict_hf_gpt2(state_dict, config) + elif model_name.startswith("facebook/opt"): + state_dict = remap_state_dict_hf_opt(state_dict, config) + elif model_name.startswith("EleutherAI/gpt-j-") or model_name.startswith( + "togethercomputer/GPT-JT-" + ): + state_dict = remap_state_dict_hf_gptj(state_dict, config) + elif ( + model_name.startswith("EleutherAI/gpt-neox-") + or model_name.startswith("EleutherAI/pythia-") + or model_name.startswith("togethercomputer/RedPajama-INCITE-") + ): + state_dict = remap_state_dict_hf_gpt_neox(state_dict, config) + elif model_name.startswith("tiiuae/falcon-"): + state_dict = remap_state_dict_hf_falcon(state_dict, config) + elif model_name.startswith("meta-llama/Llama-"): + state_dict = remap_state_dict_hf_llama(state_dict, config) + elif model_name.startswith("bigcode/") or model_name.startswith("WizardLM/"): + state_dict = remap_state_dict_hf_bigcode(state_dict, config) + else: + raise NotImplementedError(f"Model {model_name} not supported") + if world_size > 1: + state_dict = shard_state_dict_tp(state_dict, config, world_size, rank) + load_return = model.load_state_dict(state_dict, strict=strict) + logger.info(load_return) + return model + + +# https://github.com/huggingface/transformers/blob/c28d04e9e252a1a099944e325685f14d242ecdcd/src/transformers/models/gpt2/modeling_gpt2.py#L454 +def _init_weights( + module, n_layer, initializer_range=0.02, mup_width_scale=1.0, rescale_prenorm_residual=True +): + mup_init_scale = math.sqrt(mup_width_scale) + if isinstance(module, nn.Linear): + nn.init.normal_(module.weight, std=initializer_range * mup_init_scale) + optim_cfg = getattr(module.weight, "_optim", {}) + optim_cfg.update({"lr_multiplier": mup_width_scale}) + setattr(module.weight, "_optim", optim_cfg) + if module.bias is not None: + nn.init.zeros_(module.bias) + elif isinstance(module, nn.Embedding): + nn.init.normal_(module.weight, std=initializer_range) + + if rescale_prenorm_residual: + # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme: + # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale + # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
+ # > -- GPT-2 :: https://openai.com/blog/better-language-models/ + # + # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py + for name, p in module.named_parameters(): + if name in ["out_proj.weight", "fc2.weight"]: + # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block + nn.init.normal_( + p, mean=0.0, std=initializer_range * mup_init_scale / math.sqrt(2 * n_layer) + ) + + +class GPTModel(GPTPreTrainedModel): + def __init__(self, config: GPT2Config, process_group=None, device=None, dtype=None): + super().__init__(config) + factory_kwargs = {"device": device, "dtype": dtype} + self.process_group = process_group + self.sequence_parallel = getattr(config, "sequence_parallel", True) + assert config.activation_function in [ + "gelu", + "gelu_new", + "gelu_fast", + "gelu_approx", + "gelu_pytorch_tanh", + "relu", + "sqrelu", + "glu", + "swiglu", + "geglu", + ] + pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1) + vocab_size = ( + math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple + ) + self.embeddings_multiplier = getattr(config, "mup_embeddings_multiplier", 1.0) + # TD [2022-07-30]: Force residual in fp32, seems to make fp16 training more stable + self.residual_in_fp32 = getattr(config, "residual_in_fp32", False) + # These 2 options are for OPT-350m + self.prenorm = getattr(config, "prenorm", True) + use_rms_norm = getattr(config, "rms_norm", False) + word_embed_proj_dim = getattr(config, "word_embed_proj_dim", None) + # For GPT-J, GPT-NeoX + self.parallel_block = getattr(config, "parallel_block", False) + + if process_group is None: + self.embeddings = GPT2Embeddings( + config.hidden_size, + vocab_size, + config.max_position_embeddings, + word_embed_proj_dim=word_embed_proj_dim, + **factory_kwargs, + ) + else: + self.embeddings = ParallelGPT2Embeddings( + config.hidden_size, + vocab_size, + config.max_position_embeddings, + process_group=process_group, + sequence_parallel=self.sequence_parallel, + **factory_kwargs, + ) + + # We change the order of dropout, residual and layer norm: + # Instead of LN -> Attn / MLP -> Dropout -> Add, we do: + # Dropout -> Add -> LN -> Attn / MLP, returning both the residual branch (output of Add) and + # the main branch (output of MLP). The model definition is unchanged, but the mapping of the + # nn.Dropout probabilities is changed. + # This is for performance reasons: we can fuse dropout + add + layer_norm. + self.layers = nn.ModuleList( + [ + create_block(config, layer_idx=i, process_group=process_group, **factory_kwargs) + for i in range(config.num_hidden_layers) + ] + ) + rotary_emb_fraction = getattr(config, "rotary_emb_fraction", 0.0) + if rotary_emb_fraction > 0.0: # Tie all the RotaryEmbedding modules to share the same cos/sin cache + for layer in self.layers[1:]: + layer.mixer.rotary_emb = self.layers[0].mixer.rotary_emb + + self.fused_dropout_add_ln = getattr(config, "fused_dropout_add_ln", False) + if self.fused_dropout_add_ln: + if layer_norm_fn is None: + raise ImportError("Triton is not installed") + if self.prenorm: + self.drop_f = nn.Dropout(config.resid_pdrop) + norm_cls = nn.LayerNorm if not use_rms_norm else RMSNorm + self.ln_f = norm_cls( + config.hidden_size, eps=config.layer_norm_epsilon, **factory_kwargs + ) + if process_group is not None: + for p in self.ln_f.parameters(): + # Mark the norm parameters as "shared_params" so that we sync their values at init.
+ p._shared_params = True + # Mark the norm params as "sequence_parallel" so we run all-reduce on their grads. + if self.sequence_parallel: + p._sequence_parallel = True + + self.apply( + partial( + _init_weights, + n_layer=config.num_hidden_layers, + initializer_range=config.initializer_range, + mup_width_scale=getattr(config, "mup_width_scale", 1.0), + ) + ) + self.tie_weights() + + def tie_weights(self): + if self.process_group is not None: + sync_shared_params(self, self.process_group) + + def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs): + return { + i: layer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs) + for i, layer in enumerate(self.layers) + } + + def forward(self, input_ids, position_ids=None, inference_params=None): + # If using Tensor Parallel with sequence parallel, we combine the batch and the seqlen + # dimensions so that we can split on it easily, in case of small batch size. + # Only the attention layers need to know the seqlen. + embedding_kwargs = ( + {"combine_batch_seqlen_dim": True} + if self.process_group is not None and self.sequence_parallel + else {} + ) + hidden_states = self.embeddings(input_ids, position_ids=position_ids, **embedding_kwargs) + if self.embeddings_multiplier != 1.0: + hidden_states = hidden_states * self.embeddings_multiplier + if self.parallel_block: + hidden_states2 = None + residual = None + mixer_kwargs = ( + {"seqlen": input_ids.shape[1]} + if self.process_group is not None and self.sequence_parallel + else {} + ) + if inference_params is not None: + mixer_kwargs["inference_params"] = inference_params + for layer in self.layers: + if self.prenorm: + if not self.parallel_block: + hidden_states, residual = layer( + hidden_states, residual, mixer_kwargs=mixer_kwargs + ) + else: + hidden_states, hidden_states2, residual = layer( + hidden_states, hidden_states2, residual, mixer_kwargs=mixer_kwargs + ) + else: + hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs) + if self.prenorm: + if not self.fused_dropout_add_ln: + dropped = self.drop_f(hidden_states) + if not self.parallel_block: + residual = (dropped + residual) if residual is not None else dropped + else: + dropped2 = self.drop_f(hidden_states2) + residual = ( + (residual + dropped + dropped2) + if residual is not None + else dropped + dropped2 + ) + hidden_states = self.ln_f(residual.to(dtype=self.ln_f.weight.dtype)) + else: + # Set prenorm=False here since we don't need the residual + hidden_states = layer_norm_fn( + hidden_states, + self.ln_f.weight, + self.ln_f.bias, + residual=residual, + x1=None if not self.parallel_block else hidden_states2, + eps=self.ln_f.eps, + dropout_p=self.drop_f.p if self.training else 0.0, + prenorm=False, + is_rms_norm=isinstance(self.ln_f, RMSNorm) + ) + return hidden_states + + +class GPTLMHeadModel(GPTPreTrainedModel, GenerationMixin): + def __init__(self, config: GPT2Config, process_group=None, device=None, dtype=None): + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__(config) + self.process_group = process_group + self.transformer = GPTModel(config, process_group=process_group, **factory_kwargs) + self.tie_word_embeddings = getattr(config, "tie_word_embeddings", True) + lm_head_bias = getattr(config, "lm_head_bias", False) + pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1) + vocab_size = ( + math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple + ) + # This option is for OPT-350m + word_embed_proj_dim = 
getattr(config, "word_embed_proj_dim", None) + embed_dim = config.n_embd if word_embed_proj_dim is None else word_embed_proj_dim + if word_embed_proj_dim is not None: + self.project_out = nn.Linear(config.n_embd, embed_dim, bias=False, **factory_kwargs) + else: + self.project_out = None + mup_width_scale = getattr(config, "mup_width_scale", 1.0) + mup_output_multiplier = getattr(config, "mup_output_multiplier", 1.0) + self.output_scale = mup_output_multiplier * mup_width_scale + if process_group is None: + self.lm_head = nn.Linear(embed_dim, vocab_size, bias=lm_head_bias, **factory_kwargs) + else: + if ColumnParallelLinear is None: + raise ImportError("fused_dense_lib is not installed") + self.lm_head = ColumnParallelLinear( + embed_dim, + vocab_size, + process_group, + bias=lm_head_bias, + sequence_parallel=getattr(config, "sequence_parallel", True), + **factory_kwargs, + ) + self.norm_head = getattr(config, "norm_head", False) + # Initialize weights and apply final processing + self.apply( + partial( + _init_weights, + n_layer=config.num_hidden_layers, + initializer_range=config.initializer_range, + mup_width_scale=mup_width_scale, + ) + ) + self.tie_weights() + + def tie_weights(self): + if self.tie_word_embeddings: + self.lm_head.weight = self.transformer.embeddings.word_embeddings.weight + if self.process_group is not None: + sync_shared_params(self, self.process_group) + + def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs): + return self.transformer.allocate_inference_cache( + batch_size, max_seqlen, dtype=dtype, **kwargs + ) + + def forward(self, input_ids, position_ids=None, inference_params=None, num_last_tokens=0): + """ + input_ids: (batch, seqlen) int tensor + inference_params: for generation. Adapted from Megatron-LM (and Apex) + https://github.com/NVIDIA/apex/blob/3ff1a10f72ec07067c4e44759442329804ac5162/apex/transformer/testing/standalone_transformer_lm.py#L470 + num_last_tokens: if > 0, only return the logits for the last n tokens + """ + assert ( + input_ids.ndim == 2 + ), f"Expected `input_ids` to have shape [b, slen], but got shape {input_ids.shape}" + b, slen = input_ids.shape + hidden_states = self.transformer( + input_ids, position_ids=position_ids, inference_params=inference_params + ) + if inference_params is not None: + assert hidden_states.ndim == 3, "sequence_parallel is not supported in generation mode" + if num_last_tokens > 0: + hidden_states = hidden_states[:, -num_last_tokens:] + if self.project_out is not None: + hidden_states = self.project_out(hidden_states) + if self.output_scale != 1.0: + hidden_states = hidden_states * self.output_scale + if not self.norm_head: + lm_logits = self.lm_head(hidden_states) + else: + lm_head_weight = F.normalize(self.lm_head.weight) + if isinstance(self.lm_head, ColumnParallelLinear) and self.lm_head.sequence_parallel: + hidden_states = all_gather(hidden_states, self.lm_head.process_group) + lm_logits = F.linear(hidden_states, lm_head_weight, bias=self.lm_head.bias) + # During inference, we want the full logit for sampling + if isinstance(self.lm_head, ColumnParallelLinear) and inference_params is not None: + lm_logits, _ = all_gather_raw(lm_logits, self.lm_head.process_group) + lm_logits = rearrange(lm_logits, "(n b) ... d -> b ... 
(n d)", b=b) + CausalLMOutput = namedtuple("CausalLMOutput", ["logits"]) + return CausalLMOutput(logits=lm_logits) + + def load_state_dict(self, state_dict, strict=True): + # Remapping from our checkpoints that used a different ordering of layers in the block + # Previous: Attn / MLP -> Dropout -> Add -> LN + # Current: Dropout -> Add -> LN -> Attn / MLP + if "transformer.ln_0.weight" in state_dict: + n_layers = len(self.transformer.layers) + ln_weight = state_dict.pop(f"transformer.layers.{n_layers - 1}.norm2.weight") + ln_bias = state_dict.pop(f"transformer.layers.{n_layers - 1}.norm2.bias") + state_dict["transformer.ln_f.weight"] = ln_weight + state_dict["transformer.ln_f.bias"] = ln_bias + for l in reversed(range(n_layers)): + ln_weight = state_dict.pop(f"transformer.layers.{l}.norm1.weight") + ln_bias = state_dict.pop(f"transformer.layers.{l}.norm1.bias") + state_dict[f"transformer.layers.{l}.norm2.weight"] = ln_weight + state_dict[f"transformer.layers.{l}.norm2.bias"] = ln_bias + if l > 0: + ln_weight = state_dict.pop(f"transformer.layers.{l - 1}.norm2.weight") + ln_bias = state_dict.pop(f"transformer.layers.{l - 1}.norm2.bias") + state_dict[f"transformer.layers.{l}.norm1.weight"] = ln_weight + state_dict[f"transformer.layers.{l}.norm1.bias"] = ln_bias + ln_weight = state_dict.pop("transformer.ln_0.weight") + ln_bias = state_dict.pop("transformer.ln_0.bias") + state_dict[f"transformer.layers.0.norm1.weight"] = ln_weight + state_dict[f"transformer.layers.0.norm1.bias"] = ln_bias + return super().load_state_dict(state_dict, strict=strict) + + +def shard_state_dict_tp(state_dict, config, world_size, rank): + """Convert the state_dict of a standard GPT model to the state_dict of a GPT model + with tensor parallel. + + This function modifies state_dict in place. + """ + pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1) + vocab_size = math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple + assert vocab_size % world_size == 0 + assert config.hidden_size % world_size == 0 + inner_dim = config.n_inner if config.n_inner is not None else 4 * config.hidden_size + assert inner_dim % world_size == 0 + + n_head = config.n_head + n_head_kv = getattr(config, "n_head_kv", n_head) + + embed_dim = config.hidden_size + head_dim = embed_dim // n_head + + def shard_first_dim(state_dict, key): + if key in state_dict: + x = state_dict[key] + dim = x.shape[0] // world_size + state_dict[key] = x[rank * dim : (rank + 1) * dim] + + def shard_last_dim(state_dict, key, multiple_of=1): + if key in state_dict: + x = state_dict[key] + dim_each_rank = [ + get_dim_for_local_rank(x.size(-1), world_size, local_rank, multiple_of) + for local_rank in range(world_size) + ] + beg, end = tuple(sum(dim_each_rank[:pos]) for pos in (rank, rank + 1)) + state_dict[key] = x[..., beg:end] + + def shard_gatedmlp_fc1_dim(state_dict, key): + if key in state_dict: + x = state_dict[key] + dim = x.shape[0] // world_size // 2 + state_dict[key] = rearrange( + rearrange(x, "(two o) ... -> two o ...", two=2)[:, rank * dim : (rank + 1) * dim], + "two o ... 
-> (two o) ...", + ) + + def shard_qkv_headdim(state_dict, key): + if key in state_dict: + n_head_each_rank = [ + get_dim_for_local_rank(n_head, world_size, local_rank) + for local_rank in range(world_size) + ] + n_head_kv_each_rank = [ + get_dim_for_local_rank(n_head_kv, world_size, local_rank) + for local_rank in range(world_size) + ] + + beg_n_head = sum(n_head_each_rank[:rank]) + end_n_head = sum(n_head_each_rank[: rank + 1]) + + beg_n_head_kv = sum(n_head_kv_each_rank[:rank]) + end_n_head_kv = sum(n_head_kv_each_rank[: rank + 1]) + + if n_head_kv == n_head: + x = rearrange(state_dict[key], "(three d) ... -> three d ...", three=3) + state_dict[key] = rearrange( + x[:, beg_n_head * head_dim : end_n_head * head_dim], + "three d ... -> (three d) ...", + ) + else: + x = rearrange( + state_dict[key], + "(nheadqkv headdim) ... -> nheadqkv headdim ...", + nheadqkv=n_head + 2 * n_head_kv, + ) + state_dict[key] = rearrange( + torch.cat( + [ + x[beg_n_head:end_n_head], + x[n_head + beg_n_head_kv : n_head + end_n_head_kv], + x[ + n_head + + n_head_kv + + beg_n_head_kv : n_head + + n_head_kv + + end_n_head_kv + ], + ], + dim=0, + ), + "nheadqkv headdim ... -> (nheadqkv headdim) ...", + ) + + shard_first_dim(state_dict, "transformer.embeddings.word_embeddings.weight") + if "lm_head.weight" in state_dict: + shard_first_dim(state_dict, "lm_head.weight") + if "transformer.embeddings.position_embeddings.weight" in state_dict: + shard_last_dim(state_dict, "transformer.embeddings.position_embeddings.weight") + for i in range(config.num_hidden_layers): + shard_qkv_headdim(state_dict, f"transformer.layers.{i}.mixer.Wqkv.weight") + shard_qkv_headdim(state_dict, f"transformer.layers.{i}.mixer.Wqkv.bias") + shard_last_dim( + state_dict, f"transformer.layers.{i}.mixer.out_proj.weight", multiple_of=head_dim + ) + if rank != 0: + state_dict.pop(f"transformer.layers.{i}.mixer.out_proj.bias", None) + if config.activation_function in ["glu", "swiglu", "geglu"]: + shard_gatedmlp_fc1_dim(state_dict, f"transformer.layers.{i}.mlp.fc1.weight") + shard_gatedmlp_fc1_dim(state_dict, f"transformer.layers.{i}.mlp.fc1.bias") + else: + shard_first_dim(state_dict, f"transformer.layers.{i}.mlp.fc1.weight") + shard_first_dim(state_dict, f"transformer.layers.{i}.mlp.fc1.bias") + shard_last_dim(state_dict, f"transformer.layers.{i}.mlp.fc2.weight") + if rank != 0: + state_dict.pop(f"transformer.layers.{i}.mlp.fc2.bias", None) + return state_dict + + +def combine_state_dicts_tp(state_dicts: List[Dict[str, torch.Tensor]], config: GPT2Config): + """Convert the list of sharded state_dict of a GPT model with tensor parallel to + the state_dict of a standard GPT model. + + This function is meant to be the "reverse" of shard_state_dict_tp. + + Precondition: + - state_dicts should be ordered in the same way as the shards were created. + """ + world_size = len(state_dicts) + keys = state_dicts[0].keys() + pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1) + vocab_size = math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple + assert vocab_size % world_size == 0 + assert config.hidden_size % world_size == 0 + inner_dim = config.n_inner if config.n_inner is not None else 4 * config.hidden_size + assert inner_dim % world_size == 0 + assert config.hidden_size % config.n_head == 0 + headdim = config.hidden_size // config.n_head + + # Sometimes the word embeddings are sharded on the 0th dim, sometimes on the 1st dim. + # vocab_size // world_size coordinates are nonzero. 
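+    # (illustrative note, not in upstream) e.g. with vocab_size=32000 and world_size=4, a dim-0 shard of the word embeddings has shape (8000, hidden_size) while a dim-1 shard has shape (32000, hidden_size // 4); combine_word_embeddings below uses shape[0] to tell the two layouts apart before concatenating.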
+ def combine_word_embeddings(state_dicts, state_dict, key): + dim = 0 if state_dicts[0][key].shape[0] == vocab_size // world_size else 1 + state_dict[key] = torch.cat([s[key] for s in state_dicts], dim=dim) + + def combine_dim(state_dicts, state_dict, key, dim=-1): + if key in state_dict: + state_dict[key] = torch.cat([s[key] for s in state_dicts], dim=dim) + + def combine_qkv_headdim(state_dicts, state_dict, key): + n_head = config.n_head + n_head_kv = getattr(config, "n_head_kv", n_head) + if key in state_dict: + if n_head_kv == n_head: + xs = [ + rearrange(s[key], "(three d) ... -> three d ...", three=3) for s in state_dicts + ] + state_dict[key] = rearrange(torch.cat(xs, dim=1), "three d ... -> (three d) ...") + else: + n_head_each_rank = [ + get_dim_for_local_rank(n_head, world_size, local_rank) + for local_rank in range(world_size) + ] + n_head_kv_each_rank = [ + get_dim_for_local_rank(n_head_kv, world_size, local_rank) + for local_rank in range(world_size) + ] + xs = [ + rearrange( + s[key], + "(nheadqkv headdim) ... -> nheadqkv headdim ...", + nheadqkv=rank_n_head + 2 * rank_n_head_kv, + headdim=headdim, + ) + for s, rank_n_head, rank_n_head_kv in zip( + state_dicts, n_head_each_rank, n_head_kv_each_rank + ) + ] + wq = torch.cat([x[: n_head_each_rank[rank]] for rank, x in enumerate(xs)], dim=0) + wk = torch.cat( + [ + x[ + n_head_each_rank[rank] : n_head_each_rank[rank] + + n_head_kv_each_rank[rank] + ] + for rank, x in enumerate(xs) + ], + dim=0, + ) + wv = torch.cat( + [ + x[n_head_each_rank[rank] + n_head_kv_each_rank[rank] :] + for rank, x in enumerate(xs) + ], + dim=0, + ) + wqkv = torch.cat( + [wq, wk, wv], + dim=0, + ) + state_dict[key] = rearrange( + wqkv, + "nheadqkv headdim ... -> (nheadqkv headdim) ...", + ) + + def combine_gated_mlp(state_dicts, state_dict, key): + if key in state_dict: + xs = [rearrange(s[key], "(two d) ... -> two d ...", two=2) for s in state_dicts] + state_dict[key] = rearrange(torch.cat(xs, dim=1), "two d ... 
-> (two d) ...") + + state_dict = state_dicts[0].copy() # don't modify state_dict[0] inplace + combine_word_embeddings( + state_dicts, state_dict, "transformer.embeddings.word_embeddings.weight" + ) + if "lm_head.weight" in state_dict: + combine_word_embeddings(state_dicts, state_dict, "lm_head.weight") + if "transformer.embeddings.position_embeddings.weight" in state_dict: + combine_dim( + state_dicts, state_dict, "transformer.embeddings.position_embeddings.weight", -1 + ) + mlp_combine_fn = ( + combine_gated_mlp + if config.activation_function in ["glu", "swiglu", "geglu"] + else partial(combine_dim, dim=0) + ) + for i in range(config.num_hidden_layers): + combine_qkv_headdim(state_dicts, state_dict, f"transformer.layers.{i}.mixer.Wqkv.weight") + combine_qkv_headdim(state_dicts, state_dict, f"transformer.layers.{i}.mixer.Wqkv.bias") + combine_dim(state_dicts, state_dict, f"transformer.layers.{i}.mixer.out_proj.weight", -1) + mlp_combine_fn(state_dicts, state_dict, f"transformer.layers.{i}.mlp.fc1.weight") + combine_dim(state_dicts, state_dict, f"transformer.layers.{i}.mlp.fc1.bias", 0) + combine_dim(state_dicts, state_dict, f"transformer.layers.{i}.mlp.fc2.weight", -1) + return state_dict + + +def remap_state_dict_hf_gpt2(state_dict, config): + # Word embedding and position embedding + def key_mapping_pos_emb(key): + return re.sub(r"^wpe.", "transformer.embeddings.position_embeddings.", key) + + state_dict = OrderedDict((key_mapping_pos_emb(k), v) for k, v in state_dict.items()) + word_embeddings = state_dict.pop("wte.weight") + # It's possible that vocab_size is padded to be a multiple of 8, for example. + pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1) + vocab_size = math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple + state_dict["transformer.embeddings.word_embeddings.weight"] = F.pad( + word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0]) + ) + state_dict["lm_head.weight"] = state_dict["transformer.embeddings.word_embeddings.weight"] + + # LayerNorm + def key_mapping_ln(key): + key = re.sub(r"^ln_f.(weight|bias)", r"transformer.ln_f.\1", key) + key = re.sub(r"^h.(\d+).ln_(1|2).(weight|bias)", r"transformer.layers.\1.norm\2.\3", key) + return key + + state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items()) + + # MLP + for d in range(config.num_hidden_layers): + W1 = state_dict.pop(f"h.{d}.mlp.c_fc.weight") + state_dict[f"transformer.layers.{d}.mlp.fc1.weight"] = W1.t() + W2 = state_dict.pop(f"h.{d}.mlp.c_proj.weight") + state_dict[f"transformer.layers.{d}.mlp.fc2.weight"] = W2.t() + + def key_mapping_mlp(key): + key = re.sub(r"^h.(\d+).mlp.c_fc.bias", r"transformer.layers.\1.mlp.fc1.bias", key) + key = re.sub(r"^h.(\d+).mlp.c_proj.bias", r"transformer.layers.\1.mlp.fc2.bias", key) + return key + + state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items()) + + # Attention + for d in range(config.num_hidden_layers): + state_dict.pop(f"h.{d}.attn.bias", None) # We don't store this bias + Wqkv = state_dict.pop(f"h.{d}.attn.c_attn.weight") + state_dict[f"transformer.layers.{d}.mixer.Wqkv.weight"] = Wqkv.t() + Wout = state_dict.pop(f"h.{d}.attn.c_proj.weight") + state_dict[f"transformer.layers.{d}.mixer.out_proj.weight"] = Wout.t() + + def key_mapping_attn(key): + key = re.sub(r"^h.(\d+).attn.c_attn.bias", r"transformer.layers.\1.mixer.Wqkv.bias", key) + key = re.sub( + r"^h.(\d+).attn.c_proj.bias", r"transformer.layers.\1.mixer.out_proj.bias", key + ) + return key + + state_dict 
= OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items()) + + return state_dict + + +def remap_state_dict_megatron(state_dict, config): + def key_mapping_transformer(key): + key = re.sub(r"^language_model.encoder.", "transformer.", key) + key = re.sub(r"^language_model.", "transformer.", key) + return key + + state_dict = OrderedDict((key_mapping_transformer(k), v) for k, v in state_dict.items()) + + # Word embedding and position embedding + def key_mapping_pos_emb(key): + return re.sub(r"^wpe.", "transformer.embeddings.position_embeddings.", key) + + state_dict = OrderedDict((key_mapping_pos_emb(k), v) for k, v in state_dict.items()) + word_embeddings = state_dict.pop("transformer.embedding.word_embeddings.weight") + # It's possible that vocab_size is padded to be a multiple of 8, for example. + pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1) + vocab_size = ( + math.ceil(word_embeddings.shape[0] / pad_vocab_size_multiple) * pad_vocab_size_multiple + ) + state_dict["transformer.embeddings.word_embeddings.weight"] = F.pad( + word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0]) + ) + state_dict["lm_head.weight"] = state_dict["transformer.embeddings.word_embeddings.weight"] + + # LayerNorm + def key_mapping_ln(key): + key = re.sub(r"^transformer.final_layernorm.(weight|bias)", r"transformer.ln_f.\1", key) + key = re.sub( + r"^transformer.layers.(\d+).input_layernorm.(weight|bias)", + r"transformer.layers.\1.norm1.\2", + key, + ) + key = re.sub( + r"^transformer.layers.(\d+).post_attention_layernorm.(weight|bias)", + r"transformer.layers.\1.norm2.\2", + key, + ) + return key + + state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items()) + + # MLP + def key_mapping_mlp(key): + key = re.sub( + r"^transformer.layers.(\d+).mlp.dense_h_to_4h.(weight|bias)", + r"transformer.layers.\1.mlp.fc1.\2", + key, + ) + key = re.sub( + r"^transformer.layers.(\d+).mlp.dense_4h_to_h.(weight|bias)", + r"transformer.layers.\1.mlp.fc2.\2", + key, + ) + return key + + state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items()) + + # Attention + def key_mapping_attn(key): + key = re.sub( + r"^transformer.layers.(\d+).self_attention.rotary_emb.inv_freq", + r"transformer.layers.\1.mixer.rotary_emb.inv_freq", + key, + ) + key = re.sub( + r"^transformer.layers.(\d+).self_attention.query_key_value.(weight|bias)", + r"transformer.layers.\1.mixer.Wqkv.\2", + key, + ) + key = re.sub( + r"^transformer.layers.(\d+).self_attention.dense.(weight|bias)", + r"transformer.layers.\1.mixer.out_proj.\2", + key, + ) + return key + + state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items()) + # Megatron stores Wqkv as ((nheads 3 headdim), hidden_dim) + # while we store Wqkv as ((3 nheads headdim), hidden_dim) + headdim = config.hidden_size // config.num_attention_heads + for d in range(config.num_hidden_layers): + Wqkv = state_dict.pop(f"transformer.layers.{d}.mixer.Wqkv.weight") + state_dict[f"transformer.layers.{d}.mixer.Wqkv.weight"] = rearrange( + Wqkv, + "(nheads three headdim) ... 
-> (three nheads headdim) ...", + three=3, + headdim=headdim, + ) + bqkv = state_dict.pop(f"transformer.layers.{d}.mixer.Wqkv.bias") + state_dict[f"transformer.layers.{d}.mixer.Wqkv.bias"] = rearrange( + bqkv, "(nheads three headdim) -> (three nheads headdim)", three=3, headdim=headdim + ) + + return state_dict diff --git a/.venv/lib/python3.11/site-packages/xformers/_flash_attn/models/gpt_neox.py b/.venv/lib/python3.11/site-packages/xformers/_flash_attn/models/gpt_neox.py new file mode 100644 index 0000000000000000000000000000000000000000..c3894044172260a25c9c561fbaac8add91db5b23 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/xformers/_flash_attn/models/gpt_neox.py @@ -0,0 +1,124 @@ +# Copyright (c) 2023, Tri Dao. + +import math +import re +from collections import OrderedDict + +import torch +import torch.nn.functional as F +from einops import rearrange +from transformers import GPT2Config, GPTNeoXConfig + + +def remap_state_dict_hf_gpt_neox(state_dict, config): + def key_mapping_layers(key): + return re.sub(r"^gpt_neox.", "transformer.", key) + + state_dict = OrderedDict((key_mapping_layers(k), v) for k, v in state_dict.items()) + # Word embedding + def key_mapping_emb(key): + return re.sub(r"^transformer.embed_in.", "transformer.embeddings.word_embeddings.", key) + + state_dict = OrderedDict((key_mapping_emb(k), v) for k, v in state_dict.items()) + word_embeddings = state_dict.pop("transformer.embeddings.word_embeddings.weight") + # It's possible that vocab_size is padded to be a multiple of 8, for example. + pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1) + vocab_size = math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple + state_dict["transformer.embeddings.word_embeddings.weight"] = F.pad( + word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0]) + ) + if getattr(config, "tie_word_embeddings", False): + state_dict["lm_head.weight"] = state_dict["transformer.embeddings.word_embeddings.weight"] + else: + output_embeddings = state_dict.pop("embed_out.weight") + # It's possible that vocab_size is padded to be a multiple of 8, for example. 
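+        # (illustrative note, not in upstream) F.pad with pad=(0, 0, 0, n) appends n zero rows, e.g. a (50254, hidden) embed_out becomes (50256, hidden) when pad_vocab_size_multiple=8.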
+ state_dict["lm_head.weight"] = F.pad( + output_embeddings, (0, 0, 0, vocab_size - output_embeddings.shape[0]) + ) + + # LayerNorm + def key_mapping_ln(key): + key = re.sub(r"^transformer.final_layer_norm.", r"transformer.ln_f.", key) + key = re.sub( + r"^transformer.layers.(\d+).input_layernorm.", r"transformer.layers.\1.norm1.", key + ) + key = re.sub( + r"^transformer.layers.(\d+).post_attention_layernorm.", + r"transformer.layers.\1.norm2.", + key, + ) + return key + + state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items()) + + # MLP + def key_mapping_mlp(key): + key = re.sub( + r"^transformer.layers.(\d+).mlp.dense_h_to_4h.", r"transformer.layers.\1.mlp.fc1.", key + ) + key = re.sub( + r"^transformer.layers.(\d+).mlp.dense_4h_to_h.", r"transformer.layers.\1.mlp.fc2.", key + ) + return key + + state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items()) + + # Attention + for l in range(config.n_layer): + # We don't store these biases + state_dict.pop(f"transformer.layers.{l}.attention.bias") + state_dict.pop(f"transformer.layers.{l}.attention.masked_bias") + # We don't store these + state_dict.pop(f"transformer.layers.{l}.attention.rotary_emb.inv_freq", None) + # GPT-NeoX stores Wqkv as ((nheads 3 headdim), hidden_dim) + # while we store Wqkv as ((3 nheads headdim), hidden_dim) + headdim = config.hidden_size // config.num_attention_heads + Wqkv = state_dict.pop(f"transformer.layers.{l}.attention.query_key_value.weight") + state_dict[f"transformer.layers.{l}.mixer.Wqkv.weight"] = rearrange( + Wqkv, + "(nheads three headdim) ... -> (three nheads headdim) ...", + three=3, + headdim=headdim, + ) + bqkv = state_dict.pop(f"transformer.layers.{l}.attention.query_key_value.bias") + state_dict[f"transformer.layers.{l}.mixer.Wqkv.bias"] = rearrange( + bqkv, "(nheads three headdim) -> (three nheads headdim)", three=3, headdim=headdim + ) + + def key_mapping_attn(key): + key = re.sub( + r"^transformer.layers.(\d+).attention.dense.", + r"transformer.layers.\1.mixer.out_proj.", + key, + ) + return key + + state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items()) + + return state_dict + + +def gpt_neox_config_to_gpt2_config(gpt_neox_config: GPTNeoXConfig) -> GPT2Config: + assert gpt_neox_config.rotary_emb_base == 10000 + return GPT2Config( + vocab_size=gpt_neox_config.vocab_size, + n_positions=0, # No absolute position embedding + n_embd=gpt_neox_config.hidden_size, + n_layer=gpt_neox_config.num_hidden_layers, + n_head=gpt_neox_config.num_attention_heads, + n_inner=gpt_neox_config.intermediate_size, + activation_function=gpt_neox_config.hidden_act, + resid_pdrop=0.0, # No dropout + embd_pdrop=0.0, + attn_pdrop=0.0, + layer_norm_epsilon=gpt_neox_config.layer_norm_eps, + initializer_range=gpt_neox_config.initializer_range, + bos_token_id=gpt_neox_config.bos_token_id, + eos_token_id=gpt_neox_config.eos_token_id, + # These are new arguments not in the original GPT2Config + prenorm=True, + parallel_block=gpt_neox_config.use_parallel_residual, + parallel_block_tied_norm=False, + rotary_emb_fraction=gpt_neox_config.rotary_pct, + tie_word_embeddings=gpt_neox_config.tie_word_embeddings, + ) diff --git a/.venv/lib/python3.11/site-packages/xformers/_flash_attn/models/gptj.py b/.venv/lib/python3.11/site-packages/xformers/_flash_attn/models/gptj.py new file mode 100644 index 0000000000000000000000000000000000000000..ca2330d79ce5b78a1229351956da20d88e356083 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/xformers/_flash_attn/models/gptj.py 
@@ -0,0 +1,109 @@ +# Copyright (c) 2023, Tri Dao. + +import math +import re +from collections import OrderedDict + +import torch +import torch.nn.functional as F +from transformers import GPT2Config, GPTJConfig + + +def remap_state_dict_hf_gptj(state_dict, config): + def key_mapping_layers(key): + return re.sub(r"^transformer.h.", "transformer.layers.", key) + + state_dict = OrderedDict((key_mapping_layers(k), v) for k, v in state_dict.items()) + # Word embedding + def key_mapping_emb(key): + return re.sub(r"^transformer.wte.", "transformer.embeddings.word_embeddings.", key) + + state_dict = OrderedDict((key_mapping_emb(k), v) for k, v in state_dict.items()) + word_embeddings = state_dict.pop("transformer.embeddings.word_embeddings.weight") + # It's possible that vocab_size is padded to be a multiple of 8, for example. + pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1) + vocab_size = math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple + state_dict["transformer.embeddings.word_embeddings.weight"] = F.pad( + word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0]) + ) + if getattr(config, "tie_word_embeddings"): + state_dict["lm_head.weight"] = state_dict["transformer.embeddings.word_embeddings.weight"] + else: + output_embeddings = state_dict.pop("lm_head.weight") + # It's possible that vocab_size is padded to be a multiple of 8, for example. + state_dict["lm_head.weight"] = F.pad( + output_embeddings, (0, 0, 0, vocab_size - output_embeddings.shape[0]) + ) + output_embeddings_bias = state_dict.pop("lm_head.bias") + state_dict["lm_head.bias"] = F.pad( + output_embeddings_bias, (0, vocab_size - output_embeddings_bias.shape[0]) + ) + + # LayerNorm + def key_mapping_ln(key): + return re.sub(r"^transformer.layers.(\d+).ln_1.", r"transformer.layers.\1.norm1.", key) + + state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items()) + + # MLP + def key_mapping_mlp(key): + key = re.sub( + r"^transformer.layers.(\d+).mlp.fc_in.", r"transformer.layers.\1.mlp.fc1.", key + ) + key = re.sub( + r"^transformer.layers.(\d+).mlp.fc_out.", r"transformer.layers.\1.mlp.fc2.", key + ) + return key + + state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items()) + + # Attention + for l in range(config.n_layer): + Wq = state_dict.pop(f"transformer.layers.{l}.attn.q_proj.weight") + Wk = state_dict.pop(f"transformer.layers.{l}.attn.k_proj.weight") + Wv = state_dict.pop(f"transformer.layers.{l}.attn.v_proj.weight") + state_dict[f"transformer.layers.{l}.mixer.Wqkv.weight"] = torch.cat([Wq, Wk, Wv], dim=0) + # We don't store these biases + state_dict.pop(f"transformer.layers.{l}.attn.bias") + state_dict.pop(f"transformer.layers.{l}.attn.masked_bias") + + def key_mapping_attn(key): + return re.sub( + r"^transformer.layers.(\d+).attn.out_proj.", + r"transformer.layers.\1.mixer.out_proj.", + key, + ) + + state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items()) + + return state_dict + + +def gptj_config_to_gpt2_config(gptj_config: GPTJConfig) -> GPT2Config: + headdim = gptj_config.n_embd // gptj_config.n_head + return GPT2Config( + vocab_size=gptj_config.vocab_size, + n_positions=0, # No absolute position embedding + n_embd=gptj_config.n_embd, + n_layer=gptj_config.n_layer, + n_head=gptj_config.n_head, + n_inner=gptj_config.n_inner, + activation_function=gptj_config.activation_function, + resid_pdrop=gptj_config.resid_pdrop, + embd_pdrop=gptj_config.embd_pdrop, + attn_pdrop=gptj_config.attn_pdrop, + 
layer_norm_epsilon=gptj_config.layer_norm_epsilon, + initializer_range=gptj_config.initializer_range, + bos_token_id=gptj_config.bos_token_id, + eos_token_id=gptj_config.eos_token_id, + # These are new arguments not in the original GPT2Config + prenorm=True, + parallel_block=True, + parallel_block_tied_norm=True, + rotary_emb_fraction=gptj_config.rotary_dim / headdim, + rotary_emb_interleaved=True, + tie_word_embeddings=False, + qkv_proj_bias=False, + out_proj_bias=False, + lm_head_bias=True, + ) diff --git a/.venv/lib/python3.11/site-packages/xformers/_flash_attn/models/llama.py b/.venv/lib/python3.11/site-packages/xformers/_flash_attn/models/llama.py new file mode 100644 index 0000000000000000000000000000000000000000..3bfb51d17e27c1eeb5f09293b773cda8f4d81233 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/xformers/_flash_attn/models/llama.py @@ -0,0 +1,422 @@ +# Copyright (c) 2023, Tri Dao. + +import json +import math +import os +import re +from collections import OrderedDict +from pathlib import Path +from typing import Dict, List, Union + +import torch +import torch.nn.functional as F +from sentencepiece import SentencePieceProcessor +from transformers import GPT2Config, LlamaConfig + +from einops import rearrange + + +def remap_state_dict_meta_llama( + state_dict: Dict[str, torch.Tensor], config: GPT2Config +) -> Dict[str, torch.Tensor]: + """Convert the state_dict in Meta format to standard GPT format. + + This function modifies state_dict in place. + """ + + def key_mapping_layers(key): + return f"transformer.{key}" if not key.startswith("output.") else key + + state_dict = OrderedDict((key_mapping_layers(k), v) for k, v in state_dict.items()) + + # Word embedding + def key_mapping_emb(key): + return re.sub( + r"^transformer.tok_embeddings.", "transformer.embeddings.word_embeddings.", key + ) + + state_dict = OrderedDict((key_mapping_emb(k), v) for k, v in state_dict.items()) + word_embeddings = state_dict.pop("transformer.embeddings.word_embeddings.weight") + # It's possible that vocab_size is padded to be a multiple of 8, for example. + pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1) + vocab_size = ( + math.ceil(word_embeddings.shape[0] / pad_vocab_size_multiple) * pad_vocab_size_multiple + ) + state_dict["transformer.embeddings.word_embeddings.weight"] = F.pad( + word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0]) + ) + if getattr(config, "tie_word_embeddings"): + state_dict["lm_head.weight"] = state_dict["transformer.embeddings.word_embeddings.weight"] + else: + output_embeddings = state_dict.pop("output.weight") + # Need to recompute vocab_size since LLaMa shards the word embeddings and output embeddings + # differently. + vocab_size = ( + math.ceil(output_embeddings.shape[0] / pad_vocab_size_multiple) + * pad_vocab_size_multiple + ) + # It's possible that vocab_size is padded to be a multiple of 8, for example. 
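+        # (illustrative note, not in upstream) for LLaMA-7B with the default pad_vocab_size_multiple=1, output.weight already has 32000 rows, so the recomputed vocab_size is 32000 and the F.pad below is a no-op.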
+ state_dict["lm_head.weight"] = F.pad( + output_embeddings, (0, 0, 0, vocab_size - output_embeddings.shape[0]) + ) + + # LayerNorm + def key_mapping_ln(key): + key = re.sub(r"^transformer.norm.", r"transformer.ln_f.", key) + key = re.sub( + r"^transformer.layers.(\d+).attention_norm.", + r"transformer.layers.\1.norm1.", + key, + ) + key = re.sub(r"^transformer.layers.(\d+).ffn_norm.", r"transformer.layers.\1.norm2.", key) + return key + + state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items()) + + # MLP + for l in range(config.n_layer): + w1 = state_dict.pop(f"transformer.layers.{l}.feed_forward.w1.weight") + w3 = state_dict.pop(f"transformer.layers.{l}.feed_forward.w3.weight") + # Our ordering is different + state_dict[f"transformer.layers.{l}.mlp.fc1.weight"] = torch.cat([w3, w1], dim=0) + + def key_mapping_mlp(key): + return re.sub( + r"^transformer.layers.(\d+).feed_forward.w2.", + r"transformer.layers.\1.mlp.fc2.", + key, + ) + + state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items()) + + # Attention + for l in range(config.n_layer): + Wq = state_dict.pop(f"transformer.layers.{l}.attention.wq.weight") + Wk = state_dict.pop(f"transformer.layers.{l}.attention.wk.weight") + Wv = state_dict.pop(f"transformer.layers.{l}.attention.wv.weight") + state_dict[f"transformer.layers.{l}.mixer.Wqkv.weight"] = torch.cat([Wq, Wk, Wv], dim=0) + # We don't store these + state_dict.pop(f"transformer.layers.{l}.attention.inner_attention.rope.freqs", None) + + def key_mapping_attn(key): + return re.sub( + r"^transformer.layers.(\d+).attention.wo.", + r"transformer.layers.\1.mixer.out_proj.", + key, + ) + + state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items()) + + state_dict.pop("transformer.rope.freqs", None) + + return state_dict + + +def remap_state_dict_hf_llama( + state_dict: Dict[str, torch.Tensor], config: GPT2Config +) -> Dict[str, torch.Tensor]: + """Convert the state_dict in Hugging Face format to standard GPT format. + + This function modifies state_dict in place. + """ + + # Embedding + def key_mapping_emb(key): + return re.sub(r"^model.embed_tokens.", "transformer.embeddings.word_embeddings.", key) + + state_dict = OrderedDict((key_mapping_emb(k), v) for k, v in state_dict.items()) + word_embeddings = state_dict.pop("transformer.embeddings.word_embeddings.weight") + # It's possible that vocab_size is padded to be a multiple of 8, for example. + pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1) + vocab_size = ( + math.ceil(word_embeddings.shape[0] / pad_vocab_size_multiple) * pad_vocab_size_multiple + ) + state_dict["transformer.embeddings.word_embeddings.weight"] = F.pad( + word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0]) + ) + + # LM head + if getattr(config, "tie_word_embeddings"): + state_dict["lm_head.weight"] = state_dict["transformer.embeddings.word_embeddings.weight"] + else: + output_embeddings = state_dict.pop("lm_head.weight") + # Need to recompute vocab_size since LLaMa shards the word embeddings and output embeddings + # differently. + vocab_size = ( + math.ceil(output_embeddings.shape[0] / pad_vocab_size_multiple) + * pad_vocab_size_multiple + ) + # It's possible that vocab_size is padded to be a multiple of 8, for example. 
+ state_dict["lm_head.weight"] = F.pad( + output_embeddings, (0, 0, 0, vocab_size - output_embeddings.shape[0]) + ) + + # MLP + for l in range(config.n_layer): + # Fusing weights this way based on difference in the following: + # https://github.com/huggingface/transformers/blob/b42010bb1d3cbf262d27e0a328661885be46dfdb/src/transformers/models/llama/modeling_llama.py#L220 + # https://github.com/Dao-AILab/flash-attention/blob/c60851a8253257eb970e06a022c82517a8033e8c/flash_attn/modules/mlp.py#L115 + w1 = state_dict.pop(f"model.layers.{l}.mlp.gate_proj.weight") + w3 = state_dict.pop(f"model.layers.{l}.mlp.up_proj.weight") + state_dict[f"transformer.layers.{l}.mlp.fc1.weight"] = torch.cat([w3, w1], dim=0) + + def key_mapping_mlp(key): + return re.sub( + r"^model.layers.(\d+).mlp.down_proj.", + r"transformer.layers.\1.mlp.fc2.", + key, + ) + + state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items()) + + # LayerNorm + def key_mapping_ln(key): + key = re.sub(r"^model.norm.", r"transformer.ln_f.", key) + key = re.sub( + r"^model.layers.(\d+).input_layernorm.", + r"transformer.layers.\1.norm1.", + key, + ) + key = re.sub( + r"^model.layers.(\d+).post_attention_layernorm.", + r"transformer.layers.\1.norm2.", + key, + ) + return key + + state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items()) + + def inv_permute(w): + # Inverse of permute implemented in: + # https://github.com/huggingface/transformers/blob/b42010bb1d3cbf262d27e0a328661885be46dfdb/src/transformers/models/llama/convert_llama_weights_to_hf.py#L114 + return rearrange( + w, "(h two d) n -> (h d two) n", d=config.n_embd // config.n_head // 2, two=2 + ) + + # Attention + for l in range(config.n_layer): + Wq = state_dict.pop(f"model.layers.{l}.self_attn.q_proj.weight") + Wk = state_dict.pop(f"model.layers.{l}.self_attn.k_proj.weight") + Wv = state_dict.pop(f"model.layers.{l}.self_attn.v_proj.weight") + + state_dict[f"transformer.layers.{l}.mixer.Wqkv.weight"] = torch.cat( + [inv_permute(Wq), inv_permute(Wk), Wv], dim=0 + ) + # We don't store these + state_dict.pop(f"model.layers.{l}.self_attn.rotary_emb.inv_freq", None) + + def key_mapping_attn(key): + return re.sub( + r"^model.layers.(\d+).self_attn.o_proj.", + r"transformer.layers.\1.mixer.out_proj.", + key, + ) + + state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items()) + return state_dict + + +def inv_remap_state_dict_hf_llama( + state_dict: Dict[str, torch.Tensor], config: GPT2Config +) -> Dict[str, torch.Tensor]: + """Convert the state_dict in standard GPT format to Hugging Face format. + + This function is meant to be the inverse of remap_state_dict_hf_llama, up to a + multiplier pad in the embedding and lm_head. That is if the original embedding + isn't a multiple of pad_vocab_size_multiple, then + inv_remap_state_dict_hf_llama(remap_state_dict_hf_llama(state_dict)) != state_dict. + + This function modifies state_dict in place. 
+ """ + + # Embedding + def key_mapping_emb(key): + return re.sub(r"^transformer.embeddings.word_embeddings.", "model.embed_tokens.", key) + + state_dict = OrderedDict((key_mapping_emb(k), v) for k, v in state_dict.items()) + word_embeddings = state_dict.pop("model.embed_tokens.weight") + pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1) + vocab_size = ( + math.ceil(word_embeddings.shape[0] / pad_vocab_size_multiple) * pad_vocab_size_multiple + ) + state_dict["model.embed_tokens.weight"] = F.pad( + word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0]) + ) + + # LM head + if getattr(config, "tie_word_embeddings"): + state_dict["lm_head.weight"] = state_dict["model.embed_tokens.weight"] + else: + output_embeddings = state_dict.pop("lm_head.weight") + vocab_size = ( + math.ceil(output_embeddings.shape[0] / pad_vocab_size_multiple) + * pad_vocab_size_multiple + ) + state_dict["lm_head.weight"] = F.pad( + output_embeddings, (0, 0, 0, vocab_size - output_embeddings.shape[0]) + ) + + # MLP + for l in range(config.n_layer): + w3, w1 = torch.chunk( + state_dict.pop(f"transformer.layers.{l}.mlp.fc1.weight"), chunks=2, dim=0 + ) + state_dict[f"model.layers.{l}.mlp.gate_proj.weight"] = w1 + state_dict[f"model.layers.{l}.mlp.up_proj.weight"] = w3 + + def key_mapping_mlp(key): + return re.sub( + r"^transformer.layers.(\d+).mlp.fc2.", + r"model.layers.\1.mlp.down_proj.", + key, + ) + + state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items()) + + # LayerNorm + def key_mapping_ln(key): + key = re.sub(r"^transformer.ln_f.", r"model.norm.", key) + key = re.sub( + r"^transformer.layers.(\d+).norm1.", + r"model.layers.\1.input_layernorm.", + key, + ) + key = re.sub( + r"^transformer.layers.(\d+).norm2.", + r"model.layers.\1.post_attention_layernorm.", + key, + ) + return key + + state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items()) + + def permute(w): + return rearrange( + w, "(h d two) n -> (h two d) n", d=config.n_embd // config.n_head // 2, two=2 + ) + + n_head = config.n_head + n_head_kv = getattr(config, "n_head_kv", n_head) + + embed_dim = config.hidden_size + head_dim = embed_dim // n_head + + q_dim = n_head * head_dim + k_dim = v_dim = n_head_kv * head_dim + + # Attention + for l in range(config.n_layer): + Wqkv = state_dict.pop(f"transformer.layers.{l}.mixer.Wqkv.weight") + Wq = Wqkv[:q_dim] + Wk = Wqkv[q_dim : q_dim + k_dim] + Wv = Wqkv[q_dim + k_dim : q_dim + k_dim + v_dim] + state_dict[f"model.layers.{l}.self_attn.q_proj.weight"] = permute(Wq) + state_dict[f"model.layers.{l}.self_attn.k_proj.weight"] = permute(Wk) + state_dict[f"model.layers.{l}.self_attn.v_proj.weight"] = Wv + state_dict.pop(f"transformer.layers.{l}.attention.inner_attention.rope.freqs", None) + + def key_mapping_attn(key): + return re.sub( + r"^transformer.layers.(\d+).mixer.out_proj.", + r"model.layers.\1.self_attn.o_proj.", + key, + ) + + state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items()) + return state_dict + + +def config_from_meta_checkpoint( + checkpoint_path: Union[str, os.PathLike], model_name: str +) -> LlamaConfig: + """Load a LlamaConfig from a checkpoint path.""" + with open(Path(checkpoint_path) / model_name / "params.json") as f: + params = json.load(f) + config = LlamaConfig( + hidden_size=params["dim"], + intermediate_size=None, + num_attention_heads=params["n_heads"], + num_hidden_layers=params["n_layers"], + rms_norm_eps=params["norm_eps"], + num_key_value_heads=params.get("n_kv_heads", None), + ) + 
multiple_of = params.get("multiple_of", 1) + ffn_dim_multiplier = params.get("ffn_dim_multiplier", None) + + # Compute the hidden dimension of the MLP + # https://github.com/facebookresearch/llama/blob/1a240688810f8036049e8da36b073f63d2ac552c/llama/model.py#L224 + intermediate_size = 4 * config.hidden_size + # https://github.com/facebookresearch/llama/blob/1a240688810f8036049e8da36b073f63d2ac552c/llama/model.py#L195-L199 + intermediate_size = int(2 * intermediate_size / 3) + # custom dim factor multiplier + if ffn_dim_multiplier is not None: + intermediate_size = int(ffn_dim_multiplier * intermediate_size) + intermediate_size = multiple_of * ((intermediate_size + multiple_of - 1) // multiple_of) + + config.intermediate_size = intermediate_size + if "rope_theta" in params: + config.rotary_emb_base = params["rope_theta"] + config.vocab_size = 32000 + # some CodeLLaMa have vocab_size 32000, some 32016 + # Sadly it's not specified in the `params.json` file :( + tokenizer = Path(checkpoint_path) / model_name / "tokenizer.model" + if tokenizer.is_file(): + config.vocab_size = SentencePieceProcessor(str(tokenizer)).vocab_size() + return config + + +def config_from_hf_checkpoint( + checkpoint_path: Union[str, os.PathLike], model_name: str +) -> LlamaConfig: + return LlamaConfig.from_pretrained(Path(checkpoint_path) / f"{model_name}-hf" / "config.json") + + +def config_from_checkpoint( + checkpoint_path: Union[str, os.PathLike], model_name: str, checkpoint_format="meta" +) -> LlamaConfig: + if checkpoint_format == "meta": + return config_from_meta_checkpoint(checkpoint_path, model_name) + else: + return config_from_hf_checkpoint(checkpoint_path, model_name) + + +def state_dicts_from_checkpoint( + checkpoint_path: Union[str, os.PathLike], model_name: str +) -> List[dict]: + # Need to sort, otherwise we mess up the ordering and the weights are wrong + return [ + torch.load(path, map_location="cpu") + for path in sorted((Path(checkpoint_path) / model_name).glob("consolidated.*.pth")) + ] + + +def llama_config_to_gpt2_config(llama_config: LlamaConfig) -> GPT2Config: + return GPT2Config( + vocab_size=llama_config.vocab_size, + n_positions=0, # No absolute position embedding + n_embd=llama_config.hidden_size, + n_layer=llama_config.num_hidden_layers, + n_head=llama_config.num_attention_heads, + n_inner=llama_config.intermediate_size, + activation_function="swiglu", # Hardcode since HF calls it 'silu' + # Llama doesn't have dropout, idk if it's because they only release the inference code + resid_pdrop=0.0, + embd_pdrop=0.0, + attn_pdrop=0.0, + layer_norm_epsilon=llama_config.rms_norm_eps, + initializer_range=llama_config.initializer_range, + bos_token_id=llama_config.bos_token_id, + eos_token_id=llama_config.eos_token_id, + # These are new arguments not in the original GPT2Config + pad_token_id=llama_config.pad_token_id, # Idk if this does anything + rms_norm=True, + rotary_emb_fraction=1.0, + rotary_emb_interleaved=True, + tie_word_embeddings=False, + qkv_proj_bias=False, + out_proj_bias=False, + mlp_fc1_bias=False, + mlp_fc2_bias=False, + rotary_emb_base=getattr(llama_config, "rotary_emb_base", 10000.0), + n_head_kv=llama_config.num_key_value_heads, + ) diff --git a/.venv/lib/python3.11/site-packages/xformers/_flash_attn/models/opt.py b/.venv/lib/python3.11/site-packages/xformers/_flash_attn/models/opt.py new file mode 100644 index 0000000000000000000000000000000000000000..501f9eb6cf44be86aeb77a4e0f35048255850c30 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/xformers/_flash_attn/models/opt.py 
@@ -0,0 +1,116 @@ +# Copyright (c) 2023, Tri Dao. + +import math +import re +from collections import OrderedDict + +import torch +import torch.nn.functional as F +from transformers import GPT2Config, OPTConfig + + +def remap_state_dict_hf_opt(state_dict, config): + def key_mapping_model(key): + key = re.sub(r"^model.decoder.", "transformer.", key) + # The OPT-350m model uses '^decoder' instead of '^model.decoder' + key = re.sub(r"^decoder.", "transformer.", key) + return key + + state_dict = OrderedDict((key_mapping_model(k), v) for k, v in state_dict.items()) + # Word embedding and position embedding + def key_mapping_emb(key): + key = re.sub(r"^transformer.embed_tokens.", "transformer.embeddings.word_embeddings.", key) + # The OPT-350m model has project_in and project_out + key = re.sub(r"^transformer.project_in.", "transformer.embeddings.project_in.", key) + key = re.sub(r"^transformer.project_out.", "project_out.", key) + key = re.sub( + r"^transformer.embed_positions.", "transformer.embeddings.position_embeddings.", key + ) + return key + + state_dict = OrderedDict((key_mapping_emb(k), v) for k, v in state_dict.items()) + # OPT uses the first 2 indices of pos_emb for padding tokens + pos_embeddings = state_dict.pop("transformer.embeddings.position_embeddings.weight") + state_dict["transformer.embeddings.position_embeddings.weight"] = pos_embeddings[2:] + word_embeddings = state_dict.pop("transformer.embeddings.word_embeddings.weight") + # It's possible that vocab_size is padded to be a multiple of 8, for example. + pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1) + vocab_size = math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple + state_dict["transformer.embeddings.word_embeddings.weight"] = F.pad( + word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0]) + ) + state_dict["lm_head.weight"] = state_dict["transformer.embeddings.word_embeddings.weight"] + + # LayerNorm + def key_mapping_ln(key): + key = re.sub(r"^transformer.final_layer_norm.", r"transformer.ln_f.", key) + # The OPT-175B checkpoint calls this 'decoder.layer_norm' instead of 'decoder.final_layer_norm' + key = re.sub(r"^transformer.layer_norm.", r"transformer.ln_f.", key) + key = re.sub( + r"^transformer.layers.(\d+).self_attn_layer_norm.", r"transformer.layers.\1.norm1.", key + ) + key = re.sub( + r"^transformer.layers.(\d+).final_layer_norm.", r"transformer.layers.\1.norm2.", key + ) + return key + + state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items()) + + # MLP + def key_mapping_mlp(key): + return re.sub( + r"^transformer.layers.(\d+).fc(1|2).", r"transformer.layers.\1.mlp.fc\2.", key + ) + + state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items()) + + # Attention + for l in range(config.n_layer): + Wq = state_dict.pop(f"transformer.layers.{l}.self_attn.q_proj.weight") + Wk = state_dict.pop(f"transformer.layers.{l}.self_attn.k_proj.weight") + Wv = state_dict.pop(f"transformer.layers.{l}.self_attn.v_proj.weight") + bq = state_dict.pop(f"transformer.layers.{l}.self_attn.q_proj.bias") + bk = state_dict.pop(f"transformer.layers.{l}.self_attn.k_proj.bias") + bv = state_dict.pop(f"transformer.layers.{l}.self_attn.v_proj.bias") + state_dict[f"transformer.layers.{l}.mixer.Wqkv.weight"] = torch.cat([Wq, Wk, Wv], dim=0) + state_dict[f"transformer.layers.{l}.mixer.Wqkv.bias"] = torch.cat([bq, bk, bv], dim=0) + + def key_mapping_attn(key): + return re.sub( + r"^transformer.layers.(\d+).self_attn.out_proj.", +
r"transformer.layers.\1.mixer.out_proj.", + key, + ) + + state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items()) + + return state_dict + + +def opt_config_to_gpt2_config(opt_config: OPTConfig) -> GPT2Config: + assert opt_config.layerdrop == 0.0 + assert opt_config.layer_norm_elementwise_affine + word_embed_proj_dim = ( + None + if opt_config.word_embed_proj_dim == opt_config.hidden_size + else opt_config.word_embed_proj_dim + ) + return GPT2Config( + vocab_size=opt_config.vocab_size, + n_positions=opt_config.max_position_embeddings, + n_embd=opt_config.hidden_size, + n_layer=opt_config.num_hidden_layers, + n_head=opt_config.num_attention_heads, + n_inner=opt_config.ffn_dim, + activation_function=opt_config.activation_function, + resid_pdrop=opt_config.dropout, + # HF's implementation of OPT doesn't seem to have embedding dropout + embd_pdrop=opt_config.dropout, + attn_pdrop=opt_config.attention_dropout, + initializer_range=opt_config.init_std, + bos_token_id=opt_config.bos_token_id, + eos_token_id=opt_config.eos_token_id, + # These are new arguments not in the original GPT2Config + prenorm=opt_config.do_layer_norm_before, + word_embed_proj_dim=word_embed_proj_dim, + ) diff --git a/.venv/lib/python3.11/site-packages/xformers/_flash_attn/models/vit.py b/.venv/lib/python3.11/site-packages/xformers/_flash_attn/models/vit.py new file mode 100644 index 0000000000000000000000000000000000000000..4602fd7414d251e40f9d42250c23cc974d596661 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/xformers/_flash_attn/models/vit.py @@ -0,0 +1,373 @@ +# Copyright (c) 2022, Tri Dao. +# Inspired by / adapted from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py +import math +import re +from collections import OrderedDict +from copy import deepcopy +from functools import partial + +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange +from timm.models.helpers import named_apply +from torch.nn.init import trunc_normal_ +from torchvision.ops import StochasticDepth + +from flash_attn.layers.patch_embed import PatchEmbed +from flash_attn.modules.block import Block +from flash_attn.modules.mha import MHA +from flash_attn.modules.mlp import FusedMLP, Mlp + +try: + from flash_attn.ops.triton.layer_norm import layer_norm_fn +except ImportError: + layer_norm_fn = None + + +def create_mixer_cls( + num_heads, qkv_bias, attn_drop, use_flash_attn, fused_bias_fc, cross_attn=False +): + mixer_cls = partial( + MHA, + num_heads=num_heads, + cross_attn=cross_attn, + qkv_proj_bias=qkv_bias, + dropout=attn_drop, + fused_bias_fc=fused_bias_fc, + use_flash_attn=use_flash_attn, + ) + return mixer_cls + + +def create_mlp_cls(embed_dim, mlp_ratio, act_layer, fused_mlp): + inner_dim = int(embed_dim * mlp_ratio) + if not fused_mlp: + mlp_cls = partial(Mlp, hidden_features=inner_dim, activation=act_layer()) + else: + mlp_cls = partial(FusedMLP, hidden_features=inner_dim) + return mlp_cls + + +def create_block( + embed_dim, + num_heads, + mlp_ratio, + qkv_bias, + drop_rate, + attn_drop_rate, + drop_path1, + drop_path2, + norm_layer, + act_layer, + use_flash_attn, + fused_bias_fc, + fused_mlp, + fused_dropout_add_ln, + layer_idx=None, + n_layer=None, + last_layer_subset=False, +): + mixer_cls = create_mixer_cls( + num_heads, + qkv_bias, + attn_drop_rate, + use_flash_attn, + fused_bias_fc, + cross_attn=(last_layer_subset and layer_idx == n_layer - 1), + ) + mlp_cls = create_mlp_cls(embed_dim, mlp_ratio, act_layer, 
fused_mlp) + # TD [2022-10-15]: Force residual in fp32 in case of DeepSpeed + block = Block( + embed_dim, + mixer_cls, + mlp_cls, + norm_cls=norm_layer, + prenorm=True, + resid_dropout1=drop_rate, + resid_dropout2=drop_rate, + drop_path1=drop_path1, + drop_path2=drop_path2, + fused_dropout_add_ln=fused_dropout_add_ln, + residual_in_fp32=True, + ) + return block + + +class VisionTransformer(nn.Module): + """Vision Transformer + A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale` + - https://arxiv.org/abs/2010.11929 + """ + + def __init__( + self, + img_size=224, + patch_size=16, + in_chans=3, + num_classes=1000, + global_pool="token", + embed_dim=768, + depth=12, + num_heads=12, + mlp_ratio=4.0, + qkv_bias=True, + init_values=None, + class_token=True, + no_embed_class=False, + pre_norm=False, + fc_norm=None, + drop_rate=0.0, + attn_drop_rate=0.0, + drop_path_rate=0.0, + weight_init="", + embed_layer=PatchEmbed, + norm_layer=None, + act_layer=None, + use_flash_attn=False, + fused_bias_fc=False, + fused_mlp=False, + fused_dropout_add_ln=False, + ): + """ + Args: + img_size (int, tuple): input image size + patch_size (int, tuple): patch size + in_chans (int): number of input channels + num_classes (int): number of classes for classification head + global_pool (str): type of global pooling for final sequence (default: 'token') + embed_dim (int): embedding dimension + depth (int): depth of transformer + num_heads (int): number of attention heads + mlp_ratio (int): ratio of mlp hidden dim to embedding dim + qkv_bias (bool): enable bias for qkv if True + init_values: (float): layer-scale init values + class_token (bool): use class token + fc_norm (Optional[bool]): pre-fc norm after pool, set if global_pool == 'avg' if None (default: None) + drop_rate (float): dropout rate + attn_drop_rate (float): attention dropout rate + drop_path_rate (float): stochastic depth rate + weight_init (str): weight init scheme + embed_layer (nn.Module): patch embedding layer + norm_layer: (nn.Module): normalization layer + act_layer: (nn.Module): MLP activation layer + """ + super().__init__() + assert global_pool == "token", "Only support pooling with CLS token" + assert class_token + assert init_values is None, "LayerScale is not supported yet" + assert weight_init == "" + assert fc_norm is None + # pre_norm seems redundant, as there's a LayerNorm right at the start of each block, idk + assert not pre_norm + use_fc_norm = global_pool == "avg" if fc_norm is None else fc_norm + norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6) + act_layer = act_layer or nn.GELU + + self.num_classes = num_classes + self.global_pool = global_pool + self.num_features = ( + self.embed_dim + ) = embed_dim # num_features for consistency with other models + self.num_prefix_tokens = 1 if class_token else 0 + self.no_embed_class = no_embed_class + + patch_embed_extra_kwargs = ( + {"fused_bias_fc": fused_bias_fc} if embed_layer is PatchEmbed else {} + ) + self.patch_embed = embed_layer( + img_size=img_size, + patch_size=patch_size, + in_chans=in_chans, + embed_dim=embed_dim, + bias=not pre_norm, # disable bias if pre-norm is used (e.g. 
CLIP) + **patch_embed_extra_kwargs, + ) + num_patches = self.patch_embed.num_patches + + self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if class_token else None + embed_len = num_patches if no_embed_class else num_patches + self.num_prefix_tokens + self.pos_embed = nn.Parameter(torch.randn(1, embed_len, embed_dim) * 0.02) + + dpr = [ + x.item() for x in torch.linspace(0, drop_path_rate, depth) + ] # stochastic depth decay rule + + # We change the order of dropout, residual and layer norm: + # Instead of LN -> Attn / MLP -> Dropout -> Add, we do: + # Dropout -> Add -> LN -> Attn / MLP, returning both the residual branch (output of Add) and + # the main branch (output of MLP). The model definition is unchanged, but the mapping of the + # nn.Dropout probabilities are changed. + # This is for performance reason: we can fuse dropout + add + layer_norm. + self.blocks = nn.ModuleList( + [ + create_block( + embed_dim, + num_heads, + mlp_ratio, + qkv_bias, + drop_rate, + attn_drop_rate, + drop_path1=dpr[i - 1] if i > 0 else 0.0, + drop_path2=dpr[i], + norm_layer=norm_layer, + act_layer=act_layer, + use_flash_attn=use_flash_attn, + fused_bias_fc=fused_bias_fc, + fused_mlp=fused_mlp, + fused_dropout_add_ln=fused_dropout_add_ln, + layer_idx=i, + n_layer=depth, + last_layer_subset=(global_pool == "token"), + ) + for i in range(depth) + ] + ) + + self.dropout = nn.Dropout(p=drop_rate) + self.drop_path = StochasticDepth(p=dpr[-1], mode="row") + self.norm = norm_layer(embed_dim) + + self.fused_dropout_add_ln = fused_dropout_add_ln + if self.fused_dropout_add_ln and layer_norm_fn is None: + raise ImportError("Triton is not installed") + + # Classifier Head + self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() + + self.init_weights(weight_init) + + def init_weights(self, mode=""): + assert mode == "" + trunc_normal_(self.pos_embed, std=0.02) + if self.cls_token is not None: + nn.init.normal_(self.cls_token, std=1e-6) + named_apply(init_weights_vit_timm, self) + + def _init_weights(self, m): + # this fn left here for compat with downstream users + init_weights_vit_timm(m) + + @torch.jit.ignore + def no_weight_decay(self): + return {"pos_embed", "cls_token"} + + def _pos_embed(self, x): + if self.no_embed_class: + # deit-3, updated JAX (big vision) + # position embedding does not overlap with class token, add then concat + x = x + self.pos_embed + if self.cls_token is not None: + x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1) + else: + # original timm, JAX, and deit vit impl + # pos_embed has entry for class token, concat then add + if self.cls_token is not None: + x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1) + x = x + self.pos_embed + return x + + def forward_features(self, x, all_tokens=True): + """ + If all_tokens==False and self.global_pool == 'token', we only return the features for the + cls token. + """ + x = self.patch_embed(x) + hidden_states = self._pos_embed(x) + residual = None + if self.global_pool != "token" or all_tokens: + # if True: + for block in self.blocks: + hidden_states, residual = block(hidden_states, residual) + else: + for block in self.blocks[:-1]: + hidden_states, residual = block(hidden_states, residual) + # For the last layer, we only want the 1st token of the output. So we do cross-attention + # where the query is the 1st token and the key/value is the whole sequence. 
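+                # (illustrative note, not in upstream) mixer_subset=slice(0, 1) restricts the query to the CLS token, so the call below should return hidden_states of shape (batch, 1, embed_dim).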
+                hidden_states, residual = self.blocks[-1]( + hidden_states, residual, mixer_subset=slice(0, 1) + ) + if not self.fused_dropout_add_ln: + residual = self.drop_path(self.dropout(hidden_states)) + residual + hidden_states = self.norm(residual.to(dtype=self.norm.weight.dtype)) + else: + if self.drop_path.p == 0 or not self.training: + rowscale = None + else: + rowscale = self.drop_path( + torch.ones( + hidden_states.shape[:-1], + device=hidden_states.device, + dtype=hidden_states.dtype, + ) + ) + # Set prenorm=False here since we don't need the residual + hidden_states = layer_norm_fn( + hidden_states, + self.norm.weight, + self.norm.bias, + residual=residual, + eps=self.norm.eps, + dropout_p=self.dropout.p if self.training else 0.0, + rowscale=rowscale, + prenorm=False, + ) + return hidden_states + + def forward_head(self, x, pre_logits: bool = False): + if self.global_pool: + x = x[:, self.num_prefix_tokens :].mean(dim=1) if self.global_pool == "avg" else x[:, 0] + return x if pre_logits else self.head(x) + + def forward(self, x): + x = self.forward_features(x, all_tokens=False) + x = self.forward_head(x) + return x + + def load_state_dict(self, state_dict, strict=True): + patch_embed_weight = state_dict["patch_embed.proj.weight"] + if patch_embed_weight.dim() == 4: + # convert from Conv2d to Linear + state_dict["patch_embed.proj.weight"] = rearrange( + patch_embed_weight, "o c h w -> o (c h w)" + ) + + def key_mapping_attn(key): + key = re.sub(r"^blocks.(\d+).attn.qkv.", r"blocks.\1.mixer.Wqkv.", key) + key = re.sub(r"^blocks.(\d+).attn.proj.", r"blocks.\1.mixer.out_proj.", key) + return key + + state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items()) + n_layer = len(self.blocks) + # Convert from Wqkv to Wq and Wkv for cross attention (last layer) + if ( + self.blocks[-1].mixer.cross_attn + and f"blocks.{n_layer - 1}.mixer.Wqkv.weight" in state_dict + ): + Wqkv = state_dict.pop(f"blocks.{n_layer - 1}.mixer.Wqkv.weight") + bqkv = state_dict.pop(f"blocks.{n_layer - 1}.mixer.Wqkv.bias") + state_dict[f"blocks.{n_layer - 1}.mixer.Wq.weight"] = Wqkv[: self.embed_dim] + state_dict[f"blocks.{n_layer - 1}.mixer.Wkv.weight"] = Wqkv[self.embed_dim :] + state_dict[f"blocks.{n_layer - 1}.mixer.Wq.bias"] = bqkv[: self.embed_dim] + state_dict[f"blocks.{n_layer - 1}.mixer.Wkv.bias"] = bqkv[self.embed_dim :] + return super().load_state_dict(state_dict, strict=strict) + + +def init_weights_vit_timm(module: nn.Module, name: str = ""): + """ViT weight initialization, original timm impl (for reproducibility)""" + if isinstance(module, nn.Linear): + trunc_normal_(module.weight, std=0.02) + if module.bias is not None: + nn.init.zeros_(module.bias) + elif hasattr(module, "init_weights"): + module.init_weights() + + +def vit_base_patch16_224(pretrained=False, **kwargs): + """ViT-Base (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929). + ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer.
+ """ + assert not pretrained + model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs) + model = VisionTransformer(**model_kwargs) + return model diff --git a/.venv/lib/python3.11/site-packages/xformers/_flash_attn/ops/__init__.py b/.venv/lib/python3.11/site-packages/xformers/_flash_attn/ops/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/.venv/lib/python3.11/site-packages/xformers/_flash_attn/ops/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/xformers/_flash_attn/ops/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..15a7870cd32fae4011aa6979165987e237762b1f Binary files /dev/null and b/.venv/lib/python3.11/site-packages/xformers/_flash_attn/ops/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/xformers/_flash_attn/ops/__pycache__/activations.cpython-311.pyc b/.venv/lib/python3.11/site-packages/xformers/_flash_attn/ops/__pycache__/activations.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..877ea3b3b237b209f941123f3ae46d8ca5aaf093 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/xformers/_flash_attn/ops/__pycache__/activations.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/xformers/_flash_attn/ops/__pycache__/fused_dense.cpython-311.pyc b/.venv/lib/python3.11/site-packages/xformers/_flash_attn/ops/__pycache__/fused_dense.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3bf34e3cf2558506e561ad9070e44db768ab6424 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/xformers/_flash_attn/ops/__pycache__/fused_dense.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/xformers/_flash_attn/ops/__pycache__/layer_norm.cpython-311.pyc b/.venv/lib/python3.11/site-packages/xformers/_flash_attn/ops/__pycache__/layer_norm.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..67455137d70a15006d7193ff11a2b595a7f0824d Binary files /dev/null and b/.venv/lib/python3.11/site-packages/xformers/_flash_attn/ops/__pycache__/layer_norm.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/xformers/_flash_attn/ops/__pycache__/rms_norm.cpython-311.pyc b/.venv/lib/python3.11/site-packages/xformers/_flash_attn/ops/__pycache__/rms_norm.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1e61f1ec925bf858128102f463ecfb0daa22913f Binary files /dev/null and b/.venv/lib/python3.11/site-packages/xformers/_flash_attn/ops/__pycache__/rms_norm.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/xformers/_flash_attn/ops/activations.py b/.venv/lib/python3.11/site-packages/xformers/_flash_attn/ops/activations.py new file mode 100644 index 0000000000000000000000000000000000000000..b00063b6bd497e10a70a201cfe246178174aad67 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/xformers/_flash_attn/ops/activations.py @@ -0,0 +1,135 @@ +# Copied from https://github.com/mlcommons/training_results_v1.1/blob/main/NVIDIA/benchmarks/bert/implementations/pytorch/model/layers/activations.py +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F + +# 1/sqrt(2*pi)-> 0.3989423 +# 1/sqrt(2) -> 0.70710678 +# sqrt(2/pi) -> 0.79788456 + +# this function is tanh approximation of gelu +# actual gelu is: +# x * 0.5 * (1.0 + torch.erf(x * 0.70710678)) 
+@torch.jit.script +def bias_gelu(y, bias): + x = bias + y + return (x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))).to(dtype=y.dtype) + + +# gradient of tanh approximation of gelu +# gradient of actual gelu is: +# 0.5 * (1. + torch.erf(x * 0.70710678)) + 0.3989423 * x * torch.exp(-0.5 * x * x) +@torch.jit.script +def bias_gelu_back(g, y, bias): + """Assume that y has shape (B, D) and bias has shape (D)""" + x = bias + y + tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)) + # sqrt(2/pi) * 3 * 0.044715 -> 0.1070322243 + ff = 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * ( + 1 + tanh_out + ) + grad_y = ff * g + return grad_y.to(dtype=y.dtype), grad_y.sum(dim=(0), dtype=bias.dtype) + + +class GeLUFunction(torch.autograd.Function): + @staticmethod + # bias is an optional argument + def forward(ctx, input, bias): + ctx.save_for_backward(input, bias) + return bias_gelu(input, bias) + + @staticmethod + def backward(ctx, grad_output): + input, bias = ctx.saved_tensors + # bias_gelu_back returns one gradient per forward input: (grad_input, grad_bias) + grad_input, grad_bias = bias_gelu_back(grad_output, input, bias) + return grad_input, grad_bias + + +bias_gelu_impl = GeLUFunction.apply + +# this function is tanh approximation of gelu +# actual gelu is: +# x * 0.5 * (1.0 + torch.erf(x * 0.70710678)) +@torch.jit.script +def gelu_fwd(x): + return (x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))).to(dtype=x.dtype) + + +# gradient of tanh approximation of gelu +# gradient of actual gelu is: +# 0.5 * (1. + torch.erf(x * 0.70710678)) + 0.3989423 * x * torch.exp(-0.5 * x * x) +@torch.jit.script +def gelu_bwd(g, x): + tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)) + # sqrt(2/pi) * 3 * 0.044715 -> 0.1070322243 + ff = 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * ( + 1 + tanh_out + ) + return (ff * g).to(dtype=x.dtype) + + +class FastGeLUFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, input): + ctx.save_for_backward(input) + return gelu_fwd(input) + + @staticmethod + def backward(ctx, grad_output): + (input,) = ctx.saved_tensors + tmp = gelu_bwd(grad_output, input) + return tmp + + +fast_gelu_impl = FastGeLUFunction.apply + + +@torch.jit.script +def relu_bwd(g, x): + return torch.where(x >= 0, g, 0.0).to(dtype=x.dtype) + + +@torch.jit.script +def sqrelu_fwd(x): + r = F.relu(x) + return (r * r).to(dtype=x.dtype) + + +@torch.jit.script +def sqrelu_bwd(g, x): + return (2.0 * g * F.relu(x)).to(dtype=x.dtype) + + +swiglu_fwd_codestring = """ +template <typename T> T swiglu_fwd(T x, T y) { + return float(x) * float(y) / (1.0f + ::exp(-float(x))); +} +""" +swiglu_bwd_codestring = """ +template <typename T> T swiglu_bwd(T x, T y, T g, T& dx, T& dy) { + float x_sigmoid = 1.0f / (1.0f + ::exp(-float(x))); + dx = x_sigmoid * (1 + float(x) * (1.0f - x_sigmoid)) * float(g) * float(y); + dy = float(x) * x_sigmoid * float(g); +} +""" +swiglu_fwd = torch.cuda.jiterator._create_jit_fn(swiglu_fwd_codestring) +swiglu_bwd = torch.cuda.jiterator._create_multi_output_jit_fn(swiglu_bwd_codestring, num_outputs=2) + + +class SwiGLUFunction(torch.autograd.Function): + + @staticmethod + def forward(ctx, x, y): + ctx.save_for_backward(x, y) + return swiglu_fwd(x, y) + + @staticmethod + def backward(ctx, dout): + x, y = ctx.saved_tensors + return swiglu_bwd(x, y, dout) + +swiglu = SwiGLUFunction.apply diff --git a/.venv/lib/python3.11/site-packages/xformers/_flash_attn/ops/fused_dense.py b/.venv/lib/python3.11/site-packages/xformers/_flash_attn/ops/fused_dense.py new file
mode 100644 index 0000000000000000000000000000000000000000..1e45b8e609812a1545781011141ec80f6dc3af0f --- /dev/null +++ b/.venv/lib/python3.11/site-packages/xformers/_flash_attn/ops/fused_dense.py @@ -0,0 +1,688 @@ +# Copyright (c) 2023, Tri Dao. +# Inspired by https://github.com/NVIDIA/apex/blob/master/apex/fused_dense/fused_dense.py +# We make it work with pytorch amp and with bfloat16. +# The TensorParallel linear modules are inspired by https://github.com/NVIDIA/apex/blob/master/apex/transformer/tensor_parallel/layers.py +from functools import partial +from typing import Optional + +# import fused_dense_cuda # from apex +import fused_dense_lib as fused_dense_cuda +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch import Tensor +from torch.cuda.amp import custom_bwd, custom_fwd +from torch.distributed import ProcessGroup + +from flash_attn.ops.activations import gelu_bwd, relu_bwd, sqrelu_bwd, sqrelu_fwd +from flash_attn.utils.distributed import ( + all_gather_raw, + all_reduce, + all_reduce_raw, + reduce_scatter, + reduce_scatter_raw, +) + + +class FusedDenseFunc(torch.autograd.Function): + @staticmethod + @custom_fwd + def forward( + ctx, x, weight, bias, return_residual=False, process_group=None, sequence_parallel=True + ): + """ + If process_group is not None and sequence_parallel=True, we're doing Tensor Parallel + with sequence parallelism: we do an all_gather_raw of x before doing the matmul. + """ + ctx.compute_weight_gradient = weight.requires_grad + ctx.return_residual = return_residual + ctx.process_group = process_group + ctx.sequence_parallel = sequence_parallel + + if torch.is_autocast_enabled(): + x = x.to(dtype=torch.get_autocast_gpu_dtype()) + x = x.contiguous() + if process_group is not None and sequence_parallel: + # We want to kick off the all_gather early, before weight dtype conversion + total_x, handle_x = all_gather_raw(x, process_group, async_op=True) + else: + total_x = x + + if torch.is_autocast_enabled(): + weight = weight.to(dtype=torch.get_autocast_gpu_dtype()) + bias = bias.to(dtype=torch.get_autocast_gpu_dtype()) if bias is not None else None + weight = weight.contiguous() + if process_group is not None and sequence_parallel: + handle_x.wait() + batch_shape, n = total_x.shape[:-1], total_x.shape[-1] + batch_dim = batch_shape.numel() + # https://github.com/pytorch/pytorch/blob/5b51849b48a7dbccd297286cc0110def4706f9e7/aten/src/ATen/native/cuda/Blas.cpp#L174 + if min(batch_dim, n, *weight.shape) > 65535 * 32: + raise RuntimeError("fused_dense only supports matrix dims <= 2M") + output = F.linear(total_x, weight, bias) + if ctx.compute_weight_gradient: + ctx.save_for_backward(x, weight) + else: + ctx.save_for_backward(weight) + return output if not return_residual else (output, x) + + @staticmethod + @custom_bwd + def backward(ctx, grad_output, *args): + grad_output = grad_output.contiguous() + if ctx.return_residual: + (grad_input,) = args + grad_input = grad_input.contiguous() + process_group = ctx.process_group + sequence_parallel = ctx.sequence_parallel + if ctx.compute_weight_gradient: + x, weight = ctx.saved_tensors + if process_group is not None and sequence_parallel: + total_x, handle_x = all_gather_raw(x, process_group, async_op=True) + else: + total_x = x + else: + (weight,) = ctx.saved_tensors + total_x = None + batch_shape = grad_output.shape[:-1] + batch_dim = batch_shape.numel() + grad_output = grad_output.reshape(batch_dim, grad_output.shape[-1]) + if ctx.needs_input_grad[0]: + if not ctx.return_residual: + 
grad_input = F.linear(grad_output, weight.t()) + else: + grad_input = torch.addmm( + grad_input.reshape(batch_dim, grad_input.shape[-1]), grad_output, weight + ) + grad_input = grad_input.reshape(*batch_shape, grad_input.shape[-1]) + if process_group is not None: + reduce_fn = reduce_scatter_raw if sequence_parallel else all_reduce_raw + grad_input, handle_grad_input = reduce_fn(grad_input, process_group, async_op=True) + else: + grad_input = None + if ctx.needs_input_grad[1]: + assert ctx.compute_weight_gradient + if process_group is not None and sequence_parallel: + handle_x.wait() + grad_weight, grad_bias = fused_dense_cuda.linear_bias_wgrad( + total_x.reshape(batch_dim, total_x.shape[-1]), grad_output, ctx.needs_input_grad[2] + ) + else: + grad_weight = None + grad_bias = grad_output if ctx.needs_input_grad[2] else None + if process_group is not None and ctx.needs_input_grad[0]: + handle_grad_input.wait() + return grad_input, grad_weight, grad_bias, None, None, None + + +def fused_dense_func( + x: Tensor, + weight: Tensor, + bias: Optional[Tensor] = None, + return_residual: bool = False, + process_group: Optional[ProcessGroup] = None, + sequence_parallel: bool = True, +): + dtype_eligible = x.dtype in [torch.float16, torch.bfloat16] or ( + x.dtype == torch.float32 and torch.is_autocast_enabled() + ) + if x.is_cuda and weight.is_cuda and (bias is None or bias.is_cuda) and dtype_eligible: + return FusedDenseFunc.apply( + x, weight, bias, return_residual, process_group, sequence_parallel + ) + else: + assert process_group is None + out = F.linear(x, weight, bias) + return out if not return_residual else (out, x) + + +class FusedDense(nn.Linear): + def __init__( + self, + in_features: int, + out_features: int, + bias: bool = True, + return_residual: bool = False, + device=None, + dtype=None, + ) -> None: + super().__init__(in_features, out_features, bias=bias, device=device, dtype=dtype) + self.return_residual = return_residual + + def forward(self, x, process_group=None): + """ + If process_group is not None, we're doing Tensor Parallel with sequence parallelism: + we do an all_gather of x before doing the matmul. + """ + return fused_dense_func( + x, + self.weight, + self.bias, + return_residual=self.return_residual, + process_group=process_group, + ) + + +class ColumnParallelLinear(nn.Linear): + def __init__( + self, + in_features: int, + out_features: int, + process_group: ProcessGroup, + bias: bool = True, + sequence_parallel=True, + multiple_of=1, + device=None, + dtype=None, + ) -> None: + world_size = torch.distributed.get_world_size(process_group) + if out_features % multiple_of: + raise ValueError(f"out_features ({out_features}) must be a multiple of {multiple_of}") + multiple = out_features // multiple_of + # We want to split @multiple across world_size, but it could be an uneven split + div = multiple // world_size + mod = multiple % world_size + # The first @mod ranks get @div + 1 copies, the rest get @div copies + local_multiple = div + int(torch.distributed.get_rank(process_group) < mod) + super().__init__( + in_features, local_multiple * multiple_of, bias=bias, device=device, dtype=dtype + ) + self.process_group = process_group + self.sequence_parallel = sequence_parallel + + def forward(self, x): + # If self.sequence_parallel is True, we're doing Tensor Parallel with sequence parallelism: + # we do an all_gather of x before doing the matmul. + # If not, then the input is already gathered. 
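# Shape sketch for the call below (illustrative only, assuming an even split
# across world_size = W ranks): with sequence_parallel=True, x arrives as
# (seqlen / W, batch, in_features); fused_dense_func all-gathers it back to
# (seqlen, batch, in_features), multiplies by this rank's
# (out_features / W, in_features) weight shard, and returns
# (seqlen, batch, out_features / W). With sequence_parallel=False the input is
# already (seqlen, batch, in_features) and only the sharded matmul happens.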
+ return fused_dense_func( + x, + self.weight, + self.bias, + process_group=self.process_group, + sequence_parallel=self.sequence_parallel, + ) + + +class RowParallelLinear(nn.Linear): + def __init__( + self, + in_features: int, + out_features: int, + process_group: ProcessGroup, + bias: bool = True, + sequence_parallel=True, + multiple_of=1, + device=None, + dtype=None, + ) -> None: + world_size = torch.distributed.get_world_size(process_group) + rank = torch.distributed.get_rank(process_group) + if in_features % multiple_of: + raise ValueError(f"in_features ({in_features}) must be a multiple of {multiple_of}") + multiple = in_features // multiple_of + # We want to split @multiple across world_size, but it could be an uneven split + div = multiple // world_size + mod = multiple % world_size + # The first @mod ranks get @div + 1 copies, the rest get @div copies + local_multiple = div + int(torch.distributed.get_rank(process_group) < mod) + # Only rank 0 will have bias + super().__init__( + local_multiple * multiple_of, + out_features, + bias=bias and rank == 0, + device=device, + dtype=dtype, + ) + self.process_group = process_group + self.sequence_parallel = sequence_parallel + + def forward(self, x): + """ + We're doing Tensor Parallel with sequence parallelism: we do the matmul and then + a reduce_scatter of the result. + """ + out = fused_dense_func(x, self.weight, self.bias) + reduce_fn = reduce_scatter if self.sequence_parallel else all_reduce + return reduce_fn(out, self.process_group) + + +class FusedMLPFunc(torch.autograd.Function): + @staticmethod + @custom_fwd + def forward( + ctx, + x, + weight1, + bias1, + weight2, + bias2, + activation="gelu_approx", + save_pre_act=True, + return_residual=False, + checkpoint_lvl=0, + heuristic=0, + process_group=None, + sequence_parallel=True, + ): + """ + If process_group is not None and sequence_parallel=True, we're doing Tensor Parallel + with sequence parallelism: we do an all_gather of x before doing the matmul. + If sequence_parallel=False, then the input is already gathered. 
+ + checkpoint_lvl: + 0: no recomputation in the bwd + 1: recompute gelu_out / relu_out in the bwd + 2: recompute pre_act and gelu_out / relu_out in the bwd + """ + assert -1 <= heuristic <= 4 + assert activation in ["gelu_approx", "relu", "sqrelu"] + if activation == "sqrelu": + assert heuristic == -1 + if not save_pre_act: + checkpoint_lvl = 2 + assert checkpoint_lvl in [0, 1, 2] + ctx.return_residual = return_residual + ctx.process_group = process_group + ctx.sequence_parallel = sequence_parallel + ctx.checkpoint_lvl = checkpoint_lvl + ctx.activation = activation + ctx.heuristic = heuristic + + if torch.is_autocast_enabled(): + x = x.to(dtype=torch.get_autocast_gpu_dtype()) + x = x.contiguous() + if process_group is not None and sequence_parallel: + # We want to kick off the all_gather early, before weight dtype conversion + total_x, handle_x = all_gather_raw(x, process_group, async_op=True) + else: + total_x = x + + if torch.is_autocast_enabled(): + dtype = torch.get_autocast_gpu_dtype() + weight1, weight2 = [a.to(dtype=dtype) for a in [weight1, weight2]] + bias1 = bias1.to(dtype=dtype) if bias1 is not None else None + bias2 = bias2.to(dtype=dtype) if bias2 is not None else None + weight1 = weight1.contiguous() + bias1 = bias1.contiguous() if bias1 is not None else None + weight2 = weight2.contiguous() + bias2 = bias2.contiguous() if bias2 is not None else None + if process_group is not None and sequence_parallel: + handle_x.wait() + batch_shape, n = total_x.shape[:-1], total_x.shape[-1] + batch_dim = batch_shape.numel() + # https://github.com/pytorch/pytorch/blob/5b51849b48a7dbccd297286cc0110def4706f9e7/aten/src/ATen/native/cuda/Blas.cpp#L174 + if min(batch_dim, n, *weight1.shape, *weight2.shape) > 65535 * 32: + raise RuntimeError("fused_dense only supports matrix dims <= 2M") + if heuristic == -1: + pre_act = F.linear(total_x, weight1, bias1) + activation_fn = ( + partial(F.gelu, approximate="tanh") + if activation == "gelu_approx" + else (sqrelu_fwd if activation == "sqrelu" else F.relu) + ) + with torch.jit.fuser("fuser2"): + output1 = activation_fn(pre_act) + # This is before adding bias1 + # pre_act = F.linear(total_x.reshape(batch_dim, n), weight1) + # with torch.jit.fuser('fuser2'): + # output1 = bias_gelu(pre_act, bias1) + else: + is_gelu = activation == "gelu_approx" + output1, *rest = fused_dense_cuda.linear_act_forward( + total_x.reshape(batch_dim, n), weight1, bias1, is_gelu, save_pre_act, heuristic + ) + if save_pre_act: + pre_act = rest[0] + output2 = F.linear(output1, weight2, bias2) + if checkpoint_lvl == 0 or (checkpoint_lvl == 1 and activation == "relu"): + # For RELU the pre_act is very small (just a bit-mask) so we just save it + ctx.save_for_backward(x, weight1, weight2, pre_act, output1) + elif checkpoint_lvl == 1: + ctx.save_for_backward(x, weight1, weight2, pre_act) + elif checkpoint_lvl == 2: + ctx.save_for_backward(x, weight1, weight2, bias1) + output2 = output2.reshape(*batch_shape, output2.shape[-1]) + return output2 if not return_residual else (output2, x) + + @staticmethod + @custom_bwd + def backward(ctx, grad_output, *args): + grad_output = grad_output.contiguous() + checkpoint_lvl = ctx.checkpoint_lvl + activation = ctx.activation + activation_fn = ( + partial(F.gelu, approximate="tanh") + if activation == "gelu_approx" + else (sqrelu_fwd if activation == "sqrelu" else F.relu) + ) + if ctx.return_residual: + (grad_input,) = args + grad_input = grad_input.contiguous() + process_group = ctx.process_group + sequence_parallel = ctx.sequence_parallel + x, 
weight1, weight2, *rest = ctx.saved_tensors + if process_group is None or not sequence_parallel: + total_x = x + batch_shape = grad_output.shape[:-1] + batch_dim = batch_shape.numel() + if checkpoint_lvl in [0, 1]: + if process_group is not None and sequence_parallel: + total_x, handle_x = all_gather_raw(x, process_group, async_op=True) + if checkpoint_lvl == 0 or (checkpoint_lvl == 1 and activation == "relu"): + pre_act, output1 = rest + elif checkpoint_lvl == 1: + (pre_act,) = rest + with torch.jit.fuser("fuser2"): + output1 = activation_fn(pre_act) + elif checkpoint_lvl == 2: + (bias1,) = rest + if process_group is not None and sequence_parallel: + total_x, _ = all_gather_raw(x, process_group) + if ctx.heuristic == -1: + pre_act = F.linear(total_x, weight1, bias1) + with torch.jit.fuser("fuser2"): + output1 = activation_fn(pre_act) + else: + output1, pre_act = fused_dense_cuda.linear_act_forward( + total_x.reshape(batch_dim, total_x.shape[-1]), + weight1, + bias1, + activation == "gelu_approx", + True, + ctx.heuristic, + ) + + grad_output = grad_output.reshape(batch_dim, grad_output.shape[-1]) + output1 = output1.reshape(batch_dim, output1.shape[-1]) + pre_act = pre_act.reshape(batch_dim, pre_act.shape[-1]) + if ctx.needs_input_grad[3]: + grad_weight2, grad_bias2 = fused_dense_cuda.linear_bias_wgrad( + output1, grad_output, ctx.needs_input_grad[4] + ) + else: + grad_weight2 = None + grad_bias2 = grad_output if ctx.needs_input_grad[4] else None + if ctx.heuristic == -1: + # grad_pre_act = matmul_dgelu(grad_output, weight2, pre_act) + grad_output1 = F.linear(grad_output, weight2.t()) + activation_grad_fn = ( + gelu_bwd + if activation == "gelu_approx" + else (sqrelu_bwd if activation == "sqrelu" else relu_bwd) + ) + with torch.jit.fuser("fuser2"): + grad_pre_act = activation_grad_fn(grad_output1, pre_act) + else: + # The cublasLt epilogue has to compute both gelu/relu grad and bias grad, we can't + # just compute gelu/relu grad + grad_pre_act, grad_bias1 = fused_dense_cuda.bias_act_linear_dgrad_bgrad( + weight2, grad_output, pre_act, activation == "gelu_approx", ctx.heuristic + ) + if not ctx.needs_input_grad[2]: + grad_bias1 = None + if ctx.needs_input_grad[0]: + if not ctx.return_residual: + grad_input = F.linear(grad_pre_act, weight1.t()) + else: + grad_input = torch.addmm( + grad_input.reshape(batch_dim, grad_input.shape[-1]), grad_pre_act, weight1 + ) + grad_input = grad_input.reshape(*batch_shape, grad_input.shape[-1]) + if process_group is not None: + reduce_fn = reduce_scatter_raw if sequence_parallel else all_reduce_raw + grad_input, handle_grad_input = reduce_fn(grad_input, process_group, async_op=True) + else: + grad_input = None + if ctx.heuristic == -1: + if ctx.needs_input_grad[1]: + if process_group is not None and sequence_parallel and checkpoint_lvl != 2: + handle_x.wait() + grad_weight1, grad_bias1 = fused_dense_cuda.linear_bias_wgrad( + total_x.reshape(batch_dim, total_x.shape[-1]), + grad_pre_act, + ctx.needs_input_grad[2], + ) + else: + grad_weight1 = None + grad_bias1 = grad_pre_act if ctx.needs_input_grad[2] else None + else: + if ctx.needs_input_grad[1]: + if process_group is not None and sequence_parallel and checkpoint_lvl != 2: + handle_x.wait() + grad_weight1 = F.linear( + grad_pre_act.t(), total_x.reshape(batch_dim, total_x.shape[-1]).t() + ) + else: + grad_weight1 = None + if process_group is not None and ctx.needs_input_grad[0]: + handle_grad_input.wait() + return ( + grad_input, + grad_weight1, + grad_bias1, + grad_weight2, + grad_bias2, + None, + None, + 
None, + None, + None, + None, + None, + ) + + +def fused_mlp_func( + x: Tensor, + weight1: Tensor, + weight2: Tensor, + bias1: Optional[Tensor] = None, + bias2: Optional[Tensor] = None, + activation: str = "gelu_approx", + save_pre_act: bool = True, + return_residual: bool = False, + checkpoint_lvl: int = 0, + heuristic: int = 0, + process_group: Optional[ProcessGroup] = None, + sequence_parallel: bool = True, +): + assert activation in ["gelu_approx", "relu", "sqrelu"] + dtype_eligible = x.dtype in [torch.float16, torch.bfloat16] or ( + x.dtype == torch.float32 and torch.is_autocast_enabled() + ) + # If we save pre-activation, dimension must be divisible by 128 (relu) or 8 (gelu) + dim_eligible = not save_pre_act or (x.shape[-1] % (128 if activation == "relu" else 8) == 0) + if ( + x.is_cuda + and weight1.is_cuda + and weight2.is_cuda + and (bias1 is None or bias1.is_cuda) + and (bias2 is None or bias2.is_cuda) + and dtype_eligible + and dim_eligible + ): + return FusedMLPFunc.apply( + x, + weight1, + bias1, + weight2, + bias2, + activation, + save_pre_act, + return_residual, + checkpoint_lvl, + heuristic, + process_group, + sequence_parallel, + ) + else: + assert process_group is None + pre_act = F.linear(x, weight1, bias1) + activation_fn = ( + partial(F.gelu, approximate="tanh") + if activation == "gelu_approx" + else partial(F.relu, inplace=True) + ) + output1 = activation_fn(pre_act) + output2 = F.linear(output1, weight2, bias2) + return output2 if not return_residual else (output2, x) + + +class FusedMLP(nn.Module): + def __init__( + self, + in_features, + hidden_features=None, + out_features=None, + bias1=True, + bias2=True, + activation="gelu_approx", + return_residual=False, + checkpoint_lvl=0, + heuristic="auto", + device=None, + dtype=None, + ): + """ + If process_group is not None, we're doing Tensor Parallel with sequence parallelism: + we do an all_gather of x before doing the matmul, gelu, then matmul. + Finally we do a reduce_scatter of the output. + + checkpoint_lvl (increasing lvl means slower but more memory saving): + 0: no recomputation in the bwd + 1: recompute gelu_out in the bwd + 2: recompute pre_act and gelu_out in the bwd + heuristic: + -1: don't fuse gemm + gelu (separate kernel) + 0..4: use this heuristic for the algo section in the fused gemm + gelu + 'auto': heuristic will be picked automatically: + For CUDA >= 11.8, we set heuristic=0 for both fp16 and bf16 for best perf. + For CUDA <= 11.7, we set heuristic=1 for fp16 and heuristic=-1 for bf16. + For H100, we set heuristic=-1 for both fp16 and bf16 as the fused cuBlasLt implementation + is slower than the unfused version. + return_residual: whether to return the input x along with the output. This is for + performance reason: for post-norm architecture, returning the input allows us + to fuse the backward of nn.Linear with the residual connection. 
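The checkpoint_lvl levels described above trade backward-pass compute for activation memory. A minimal sketch of the lvl-1 idea, recomputing the activation output in backward instead of saving it (a hypothetical unfused GELU MLP, not the fused implementation in this file):

import torch
import torch.nn.functional as F

class RecomputeGeluMLP(torch.autograd.Function):
    # checkpoint_lvl=1 in miniature: save pre_act, recompute gelu(pre_act) in bwd.

    @staticmethod
    def forward(ctx, x, w1, w2):
        pre_act = x @ w1.t()
        out = F.gelu(pre_act, approximate="tanh") @ w2.t()
        ctx.save_for_backward(x, w1, w2, pre_act)  # the GELU output is not saved
        return out

    @staticmethod
    def backward(ctx, dout):
        x, w1, w2, pre_act = ctx.saved_tensors
        gelu_out = F.gelu(pre_act, approximate="tanh")  # recomputed here
        dw2 = dout.t() @ gelu_out
        dgelu = dout @ w2
        # Derivative of the tanh-approximated GELU (same constants as activations.py).
        t = torch.tanh(0.79788456 * pre_act * (1 + 0.044715 * pre_act * pre_act))
        dpre = dgelu * (
            0.5 * pre_act * (1 - t * t) * (0.79788456 + 0.1070322243 * pre_act * pre_act)
            + 0.5 * (1 + t)
        )
        return dpre @ w1, dpre.t() @ x, dw2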
+ """ + assert checkpoint_lvl in [0, 1, 2] + assert activation in ["gelu_approx", "relu", "sqrelu"] + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features * 4 + self.activation = activation + self.return_residual = return_residual + self.checkpoint_lvl = checkpoint_lvl + self.heuristic = heuristic if activation != "sqrelu" else -1 + self.fc1 = nn.Linear(in_features, hidden_features, bias=bias1, **factory_kwargs) + self.fc2 = nn.Linear(hidden_features, out_features, bias=bias2, **factory_kwargs) + + def forward(self, x, process_group=None): + dtype = x.dtype if not torch.is_autocast_enabled() else torch.get_autocast_gpu_dtype() + if self.heuristic == "auto": + if self.activation == "gelu_approx": + if torch.cuda.get_device_capability("cuda") == (9, 0): + heuristic = -1 + else: + cuda_ver = tuple(map(int, torch.version.cuda.split("."))) + heuristic = 0 if cuda_ver >= (11, 8) else (1 if dtype == torch.float16 else -1) + else: + heuristic = 0 + else: + heuristic = self.heuristic + out = fused_mlp_func( + x, + self.fc1.weight, + self.fc2.weight, + self.fc1.bias, + self.fc2.bias, + activation=self.activation, + save_pre_act=self.training, + return_residual=self.return_residual, + checkpoint_lvl=self.checkpoint_lvl, + heuristic=heuristic, + process_group=process_group, + ) + if self.return_residual: + out, x = out + if process_group is not None: + out = reduce_scatter(out, process_group) + return out if not self.return_residual else (out, x) + + +class ParallelFusedMLP(nn.Module): + def __init__( + self, + in_features, + hidden_features=None, + out_features=None, + activation="gelu_approx", + process_group: ProcessGroup = None, + bias1=True, + bias2=True, + sequence_parallel=True, + checkpoint_lvl=0, + heuristic="auto", + device=None, + dtype=None, + ): + """ + process_group is required. We're doing Tensor Parallel with sequence parallelism: + we do an all_gather of x before doing the matmul, gelu, then matmul. + Finally we do a reduce_scatter of the output. + + checkpoint_lvl (increasing lvl means slower but more memory saving): + 0: no recomputation in the bwd + 1: recompute gelu_out in the bwd + 2: recompute pre_act and gelu_out in the bwd + heuristic: + -1: don't fuse gemm + gelu (separate kernel) + 0..4: use this heuristic for the algo section in the fused gemm + gelu + 'auto': heuristic will be picked automatically: + For CUDA >= 11.8, we set heuristic=0 for both fp16 and bf16 for best perf. + For CUDA <= 11.7, we set heuristic=1 for fp16 and heuristic=-1 for bf16. 
+ """ + assert checkpoint_lvl in [0, 1, 2] + assert activation in ["gelu_approx", "relu", "sqrelu"] + assert process_group is not None + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features * 4 + self.activation = activation + self.process_group = process_group + self.sequence_parallel = sequence_parallel + self.checkpoint_lvl = checkpoint_lvl + self.heuristic = heuristic if activation != "sqrelu" else -1 + self.fc1 = ColumnParallelLinear( + in_features, hidden_features, process_group, bias=bias1, **factory_kwargs + ) + self.fc2 = RowParallelLinear( + hidden_features, out_features, process_group, bias=bias2, **factory_kwargs + ) + + def forward(self, x): + dtype = x.dtype if not torch.is_autocast_enabled() else torch.get_autocast_gpu_dtype() + if self.heuristic == "auto": + if self.activation == "gelu_approx": + cuda_ver = tuple(map(int, torch.version.cuda.split("."))) + heuristic = 0 if cuda_ver >= (11, 8) else (1 if dtype == torch.float16 else -1) + else: + heuristic = 0 + else: + heuristic = self.heuristic + out = fused_mlp_func( + x, + self.fc1.weight, + self.fc2.weight, + self.fc1.bias, + self.fc2.bias, + activation=self.activation, + save_pre_act=self.training, + checkpoint_lvl=self.checkpoint_lvl, + heuristic=heuristic, + process_group=self.process_group, + sequence_parallel=self.sequence_parallel, + ) + reduce_fn = reduce_scatter if self.sequence_parallel else all_reduce + return reduce_fn(out, self.process_group) diff --git a/.venv/lib/python3.11/site-packages/xformers/_flash_attn/ops/rms_norm.py b/.venv/lib/python3.11/site-packages/xformers/_flash_attn/ops/rms_norm.py new file mode 100644 index 0000000000000000000000000000000000000000..068348d61290e3839dd082b540d898578ba1e8e2 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/xformers/_flash_attn/ops/rms_norm.py @@ -0,0 +1,174 @@ +# Copyright (c) 2022, Tri Dao. +# Adapted from https://github.com/NVIDIA/apex/blob/master/apex/contrib/layer_norm/layer_norm.py + +import torch +from torch.nn import init + +from flash_attn.ops.layer_norm import ( + DropoutAddLayerNormFn, + DropoutAddLayerNormParallelResidualFn, + DropoutAddLayerNormSubsetFn, +) + + +def rms_norm(x, weight, epsilon): + return DropoutAddLayerNormFn.apply( + x, None, weight, None, None, None, 0.0, epsilon, False, False, True + ) + + +def dropout_add_rms_norm( + x0, + residual, + weight, + bias, + dropout_p, + epsilon, + rowscale=None, + layerscale=None, + prenorm=False, + residual_in_fp32=False, + return_dropout_mask=False, +): + """residual_in_fp32 only has an effect if residual is None. + Otherwise residual dtype is residual.dtype. + """ + return DropoutAddLayerNormFn.apply( + x0, + residual, + weight, + bias, + rowscale, + layerscale, + dropout_p, + epsilon, + residual_in_fp32, + prenorm, + True, + return_dropout_mask, + ) + + +def dropout_add_rms_norm_subset( + x0, + residual, + weight, + bias, + dropout_p, + epsilon, + layerscale=None, + x0_subset=None, + out_subset=None, + rowscale_const=1.0, + out_numrows=0, + prenorm=False, + residual_in_fp32=False, + return_dropout_mask=False, +): + """residual_in_fp32 only has an effect if residual is None. + Otherwise residual dtype is residual.dtype. 
+ """ + return DropoutAddLayerNormSubsetFn.apply( + x0, + residual, + weight, + bias, + layerscale, + x0_subset, + out_subset, + dropout_p, + epsilon, + rowscale_const, + out_numrows, + residual_in_fp32, + prenorm, + True, + return_dropout_mask, + ) + + +def dropout_add_rms_norm_parallel_residual( + x0, + x1, + residual, + weight0, + bias0, + weight1, + bias1, + dropout_p, + epsilon, + prenorm=False, + residual_in_fp32=False, + return_dropout_mask=False, +): + """residual_in_fp32 only has an effect if residual is None. + Otherwise residual dtype is residual.dtype. + """ + return DropoutAddLayerNormParallelResidualFn.apply( + x0, + x1, + residual, + weight0, + bias0, + weight1, + bias1, + dropout_p, + epsilon, + residual_in_fp32, + prenorm, + True, + return_dropout_mask, + ) + + +class RMSNorm(torch.nn.Module): + def __init__(self, hidden_size, eps=1e-5, device=None, dtype=None): + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + self.eps = eps + self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs)) + self.register_parameter("bias", None) + self.reset_parameters() + + def reset_parameters(self): + init.ones_(self.weight) + + def forward(self, x): + return rms_norm(x, self.weight, self.eps) + + +class DropoutAddRMSNorm(torch.nn.Module): + def __init__( + self, + hidden_size, + prenorm=False, + p=0.0, + eps=1e-5, + residual_in_fp32=False, + device=None, + dtype=None, + ): + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + self.prenorm = prenorm + self.p = p + self.eps = eps + self.residual_in_fp32 = residual_in_fp32 + self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs)) + self.register_parameter("bias", None) + self.reset_parameters() + + def reset_parameters(self): + init.ones_(self.weight) + + def forward(self, x0, residual=None): + return dropout_add_rms_norm( + x0, + residual, + self.weight, + None, + self.p if self.training else 0.0, + self.eps, + prenorm=self.prenorm, + residual_in_fp32=self.residual_in_fp32, + ) diff --git a/.venv/lib/python3.11/site-packages/xformers/_flash_attn/ops/triton/__init__.py b/.venv/lib/python3.11/site-packages/xformers/_flash_attn/ops/triton/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8b137891791fe96927ad78e64b0aad7bded08bdc --- /dev/null +++ b/.venv/lib/python3.11/site-packages/xformers/_flash_attn/ops/triton/__init__.py @@ -0,0 +1 @@ + diff --git a/.venv/lib/python3.11/site-packages/xformers/_flash_attn/ops/triton/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/xformers/_flash_attn/ops/triton/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..78cf0bae78801c1a12bfe62b083ecf67f01577e3 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/xformers/_flash_attn/ops/triton/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/xformers/_flash_attn/ops/triton/__pycache__/cross_entropy.cpython-311.pyc b/.venv/lib/python3.11/site-packages/xformers/_flash_attn/ops/triton/__pycache__/cross_entropy.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..59e266242f4a96489bee30a9f0dc20107a0fa92b Binary files /dev/null and b/.venv/lib/python3.11/site-packages/xformers/_flash_attn/ops/triton/__pycache__/cross_entropy.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/xformers/_flash_attn/ops/triton/__pycache__/k_activations.cpython-311.pyc 
b/.venv/lib/python3.11/site-packages/xformers/_flash_attn/ops/triton/__pycache__/k_activations.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a68bef31685f85fad03159a7d871a6f421e16274 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/xformers/_flash_attn/ops/triton/__pycache__/k_activations.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/xformers/_flash_attn/ops/triton/__pycache__/layer_norm.cpython-311.pyc b/.venv/lib/python3.11/site-packages/xformers/_flash_attn/ops/triton/__pycache__/layer_norm.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..18e245e06dbabd1e09749cfdc8b3d86a80b6f696 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/xformers/_flash_attn/ops/triton/__pycache__/layer_norm.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/xformers/_flash_attn/ops/triton/__pycache__/linear.cpython-311.pyc b/.venv/lib/python3.11/site-packages/xformers/_flash_attn/ops/triton/__pycache__/linear.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..791f5d275efa39f54ef7f45d28403651000ed8f7 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/xformers/_flash_attn/ops/triton/__pycache__/linear.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/xformers/_flash_attn/ops/triton/__pycache__/mlp.cpython-311.pyc b/.venv/lib/python3.11/site-packages/xformers/_flash_attn/ops/triton/__pycache__/mlp.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..44b2069223bbc737827ace8a3edc47445d621605 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/xformers/_flash_attn/ops/triton/__pycache__/mlp.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/xformers/_flash_attn/ops/triton/__pycache__/rotary.cpython-311.pyc b/.venv/lib/python3.11/site-packages/xformers/_flash_attn/ops/triton/__pycache__/rotary.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c459133d7efef88d3c7943cca94d86e018afe793 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/xformers/_flash_attn/ops/triton/__pycache__/rotary.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/xformers/_flash_attn/ops/triton/cross_entropy.py b/.venv/lib/python3.11/site-packages/xformers/_flash_attn/ops/triton/cross_entropy.py new file mode 100644 index 0000000000000000000000000000000000000000..8f7e9a23072c6cf3a4832b945408733edf5dec14 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/xformers/_flash_attn/ops/triton/cross_entropy.py @@ -0,0 +1,325 @@ +# Copyright (c) 2023, Tri Dao. + +from typing import Tuple, Optional, Union + +import torch + +import triton +import triton.language as tl + +# `all_gather_into_tensor` and `reduce_scatter_tensor` are new placeholders for +# `_all_gather_base` and `_reduce_scatter_base`. They require the most recent +# version of PyTorch. The following 2 lines are for backward compatibility with +# older PyTorch. 
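The forward kernel defined further below computes each row's log-sum-exp one BLOCK_SIZE chunk at a time, rescaling a running maximum and sum (the online-softmax trick). The same update in plain PyTorch (an illustrative sketch, not the kernel itself):

import torch

def online_lse(logits, block=4):
    # Running max m and rescaled sum l per row, updated chunk by chunk.
    m = torch.full(logits.shape[:1], -float("inf"))
    l = torch.zeros(logits.shape[:1])
    for start in range(0, logits.shape[1], block):
        chunk = logits[:, start : start + block]
        m_new = torch.maximum(m, chunk.max(dim=1).values)
        l = torch.exp(m - m_new) * l + torch.exp(chunk - m_new[:, None]).sum(dim=1)
        m = m_new
    return torch.log(l) + m

x = torch.randn(3, 10)
assert torch.allclose(online_lse(x), torch.logsumexp(x, dim=1), atol=1e-5)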
+if "all_gather_into_tensor" not in dir(torch.distributed): + torch.distributed.all_gather_into_tensor = torch.distributed._all_gather_base + + +@triton.heuristics( + { + "HAS_SMOOTHING": lambda args: args["smoothing"] > 0.0, + } +) +@triton.jit +def cross_entropy_fwd_kernel( + loss_ptr, # data ptrs + lse_ptr, + z_loss_ptr, + logits_ptr, + labels_ptr, + smoothing, + logit_scale, + lse_square_scale, + ignore_index, + total_classes, + class_start_idx, # Useful for tensor parallel when each rank only has a subset of classes + n_cols, # shapes + logits_row_stride, # strides + BLOCK_SIZE: tl.constexpr, + HAS_SMOOTHING: tl.constexpr, + # if SPLIT (e.g. tensor parallel), don't include the LSE in the loss since it's not the final LSE + SPLIT: tl.constexpr, + PRECOMPUTED_LSE: tl.constexpr, # If LSE is already computed (also no smoothing and logit_scale == 1.0) +): + row_idx = tl.program_id(0) + logits_ptr = logits_ptr + row_idx * logits_row_stride.to(tl.int64) + sum_logits = 0.0 # For smoothing + if not PRECOMPUTED_LSE: + # Statistics for online softmax + m_i = -float("inf") + l_i = 0.0 + for col_offset in range(0, n_cols, BLOCK_SIZE): + cols = col_offset + tl.arange(0, BLOCK_SIZE) + logits = tl.load(logits_ptr + cols, mask=cols < n_cols, other=-float("inf")).to( + tl.float32 + ) * logit_scale + if HAS_SMOOTHING: + sum_logits += tl.sum(tl.where(cols < n_cols, logits, 0.0)) + m_i_new = tl.maximum(m_i, tl.max(logits)) + l_i = tl.exp(m_i - m_i_new) * l_i + tl.sum(tl.exp(logits - m_i_new)) + m_i = m_i_new + lse = tl.log(l_i) + m_i + tl.store(lse_ptr + row_idx, lse) + else: + lse = tl.load(lse_ptr + row_idx) + label_idx = tl.load(labels_ptr + row_idx) + if label_idx == ignore_index: + loss = 0.0 + z_loss = 0.0 + else: + label_idx -= class_start_idx + if label_idx >= 0 and label_idx < n_cols: + logits_label = tl.load(logits_ptr + label_idx) * logit_scale + if HAS_SMOOTHING: + loss = ( + (lse if not SPLIT else 0.0) + - smoothing * sum_logits / total_classes + - (1 - smoothing) * logits_label + ) + else: + loss = (lse if not SPLIT else 0.0) - logits_label + else: + # If label is out of bounds, we set the CE loss to 0.0. 
But we still want the smoothing loss + if HAS_SMOOTHING: + loss = smoothing * ((lse if not SPLIT else 0.0) - sum_logits / total_classes) + else: + loss = 0.0 + if not SPLIT: + z_loss = lse_square_scale * lse * lse + loss += z_loss + else: + z_loss = 0.0 + tl.store(loss_ptr + row_idx, loss) + if not SPLIT: + tl.store(z_loss_ptr + row_idx, z_loss) + + +@triton.heuristics( + { + "HAS_SMOOTHING": lambda args: args["smoothing"] > 0.0, + } +) +@triton.jit +def cross_entropy_bwd_kernel( + dlogits_ptr, # data ptrs + dloss_ptr, + logits_ptr, + lse_ptr, + labels_ptr, + smoothing, + logit_scale, + lse_square_scale, + ignore_index, + total_classes, + class_start_idx, # Useful for tensor parallel when each rank only has a subset of classes + n_cols, # shapes + logits_row_stride, # strides + dlogits_row_stride, + dloss_row_stride, + BLOCK_SIZE: tl.constexpr, + HAS_SMOOTHING: tl.constexpr, +): + row_idx = tl.program_id(0) + col_block_idx = tl.program_id(1) + logits_ptr = logits_ptr + row_idx * logits_row_stride.to(tl.int64) + dlogits_ptr = dlogits_ptr + row_idx * dlogits_row_stride.to(tl.int64) + col_offsets = col_block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + label_idx = tl.load(labels_ptr + row_idx) + if label_idx != ignore_index: + dloss = tl.load(dloss_ptr + row_idx * dloss_row_stride) + else: + dloss = 0.0 + logits = tl.load(logits_ptr + col_offsets, mask=col_offsets < n_cols, other=-float("inf")).to( + tl.float32 + ) * logit_scale + lse = tl.load(lse_ptr + row_idx) + probs = tl.exp(logits - lse) + probs += 2.0 * lse_square_scale * lse * probs + label_idx -= class_start_idx + if HAS_SMOOTHING: + smooth_positive = 1.0 - smoothing + smooth_negative = smoothing / total_classes + probs = tl.where(col_offsets == label_idx, probs - smooth_positive, probs) - smooth_negative + else: + probs = tl.where(col_offsets == label_idx, probs - 1.0, probs) + tl.store(dlogits_ptr + col_offsets, (dloss * logit_scale) * probs, mask=col_offsets < n_cols) + + +class CrossEntropyLoss(torch.autograd.Function): + + @staticmethod + def forward( + ctx, + logits, + labels, + precomputed_lse=None, + smoothing=0.0, + logit_scale=1.0, + lse_square_scale=0.0, + ignore_index=-100, + inplace_backward=False, + process_group=None, + ): + n_rows, n_cols = logits.shape + assert labels.shape == (n_rows,) + world_size = 1 if process_group is None else torch.distributed.get_world_size(process_group) + total_classes = world_size * n_cols + rank = 0 if process_group is None else torch.distributed.get_rank(process_group) + class_start_idx = rank * n_cols + use_precomputed_lse = precomputed_lse is not None and logit_scale == 1.0 and smoothing == 0.0 + + if logits.stride(-1) != 1: + logits = logits.contiguous() + MAX_BLOCK_SIZE = 16 * 1024 + BLOCK_SIZE = min(triton.next_power_of_2(n_cols), MAX_BLOCK_SIZE) + num_warps = ( + 4 + if BLOCK_SIZE < 2048 + else (8 if BLOCK_SIZE < 8192 else (16 if BLOCK_SIZE < 128 * 1024 else 32)) + ) + losses = torch.empty(n_rows, dtype=torch.float, device=logits.device) + if use_precomputed_lse: + assert precomputed_lse.shape == (n_rows,) + lse = precomputed_lse.contiguous() + else: + lse = torch.empty(n_rows, dtype=torch.float, device=logits.device) + z_losses = torch.empty(n_rows, dtype=torch.float, device=logits.device) + # Need this, otherwise Triton tries to launch from cuda:0 and we get + # ValueError: Pointer argument (at 0) cannot be accessed from Triton (cpu tensor?) 
+ with torch.cuda.device(logits.device.index): + cross_entropy_fwd_kernel[(n_rows,)]( + losses, # data ptrs + lse, + z_losses, + logits, + labels, + smoothing, + logit_scale, + lse_square_scale, + ignore_index, + total_classes, + class_start_idx, + n_cols, # shapes + logits.stride(0), # strides + BLOCK_SIZE=BLOCK_SIZE, # constants + SPLIT=world_size > 1, + PRECOMPUTED_LSE=use_precomputed_lse, + num_warps=num_warps, + ) + + if world_size > 1: + # If there's no smoothing, if labels are in the vocab of this partition, losses contains + # - predicted logit, and 0 otherwise. + # If there's smoothing=0.1, for labels in the vocab of this partition, losses contains + # -0.9 * predicted logit - 0.1 * sum logit / total_classes. + # For labels not in the vocab of this partition, losses contains + # -0.1 * sum logit / total_classes. + lse_allgather = torch.empty(world_size, n_rows, dtype=lse.dtype, device=lse.device) + torch.distributed.all_gather_into_tensor(lse_allgather, lse, group=process_group) + handle_losses = torch.distributed.all_reduce( + losses, op=torch.distributed.ReduceOp.SUM, group=process_group, async_op=True + ) + lse = torch.logsumexp(lse_allgather, dim=0) + handle_losses.wait() + # After the allreduce, if there's no smoothing, the total losses are - predicted_logit, + # we just have to add the (global) lse. + # If there's smoothing=0.1, the total losses are + # -0.9 * predicted_logit - 0.1 * sum logit / total_classes. + # Again, we just have to add the (global) lse. + losses += lse + if lse_square_scale != 0.0: + z_losses = lse_square_scale * lse.square() + z_losses.masked_fill_(labels == ignore_index, 0.0) + losses += z_losses + else: + z_losses = torch.zeros_like(losses) + losses.masked_fill_(labels == ignore_index, 0.0) + + ctx.save_for_backward(logits, lse, labels) + ctx.mark_non_differentiable(z_losses) + ctx.smoothing = smoothing + ctx.logit_scale = logit_scale + ctx.lse_square_scale = lse_square_scale + ctx.ignore_index = ignore_index + ctx.total_classes = total_classes + ctx.class_start_idx = class_start_idx + ctx.inplace_backward = inplace_backward + + return losses, z_losses + + @staticmethod + def backward(ctx, grad_losses, grad_z_losses): + del grad_z_losses # z_losses are only for logging. + + logits, lse, labels = ctx.saved_tensors + dlogits = logits if ctx.inplace_backward else torch.empty_like(logits) + n_rows, n_cols = logits.shape + BLOCK_SIZE = min(triton.next_power_of_2(n_cols), 4 * 1024) + num_warps = 4 if BLOCK_SIZE < 2048 else (8 if BLOCK_SIZE < 8192 else 16) + grid = lambda META: (n_rows, triton.cdiv(n_cols, META["BLOCK_SIZE"])) # noqa + # Need this, otherwise Triton tries to launch from cuda:0 and we get + # ValueError: Pointer argument (at 0) cannot be accessed from Triton (cpu tensor?)
+ with torch.cuda.device(logits.device.index): + cross_entropy_bwd_kernel[grid]( + dlogits, # data ptrs + grad_losses, + logits, + lse, + labels, + ctx.smoothing, + ctx.logit_scale, + ctx.lse_square_scale, + ctx.ignore_index, + ctx.total_classes, + ctx.class_start_idx, + n_cols, # shapes + logits.stride(0), # strides + dlogits.stride(0), + grad_losses.stride(0), + BLOCK_SIZE=BLOCK_SIZE, # constants + num_warps=num_warps, + ) + return dlogits, None, None, None, None, None, None, None, None, None + + +def cross_entropy_loss( + logits: torch.Tensor, + labels: torch.Tensor, + precomputed_lse: Optional[torch.Tensor] = None, + label_smoothing: float = 0.0, + logit_scale: float = 1.0, + lse_square_scale: float = 0.0, + ignore_index=-100, + inplace_backward: bool = False, + process_group=None, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Arguments: + logits: (batch, vocab_size) + labels: (batch,) + label_smoothing: float + logit_scale: float. Multiply logits by this scale before calculating the loss. + lse_square_scale: float. If > 0, we add lse_square_scale * lse(logits) ^ 2 to the loss. + This is also referred to as "z-loss". + ignore_index: int. If labels == ignore_index, the loss is set to 0.0. + inplace_backward: bool. If True, we do the backward pass in-place by modifying the logits. + This saves memory. + process_group: if not None, we're doing Tensor Parallel: each process is responsible for + one part of the vocab. The loss will be aggregated across processes. + Returns: + losses: (batch,), float + z_losses: (batch,), float + """ + return CrossEntropyLoss.apply( + logits, + labels, + precomputed_lse, + label_smoothing, + logit_scale, + lse_square_scale, + ignore_index, + inplace_backward, + process_group, + ) diff --git a/.venv/lib/python3.11/site-packages/xformers/_flash_attn/ops/triton/k_activations.py b/.venv/lib/python3.11/site-packages/xformers/_flash_attn/ops/triton/k_activations.py new file mode 100644 index 0000000000000000000000000000000000000000..efb83c358eb4a85d069ee340a3c83f418f9a805b --- /dev/null +++ b/.venv/lib/python3.11/site-packages/xformers/_flash_attn/ops/triton/k_activations.py @@ -0,0 +1,162 @@ +# Adapted from https://github.com/facebookresearch/xformers/blob/main/xformers/triton/k_activations.py +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. 
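A usage sketch for the cross_entropy_loss wrapper above (single process; requires a CUDA device since the kernels are Triton; shapes follow the docstring, values are hypothetical):

import torch
# cross_entropy_loss as defined above; lse_square_scale and ignore_index per its docstring.

logits = torch.randn(4, 32000, device="cuda", requires_grad=True)
labels = torch.tensor([1, 7, -100, 42], device="cuda")  # the -100 row contributes 0 loss
losses, z_losses = cross_entropy_loss(
    logits,
    labels,
    lse_square_scale=1e-4,   # adds 1e-4 * lse(logits)**2 per row (the "z-loss")
    inplace_backward=True,   # backward writes gradients into the logits buffer
)
losses.sum().backward()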
+ +import math +from enum import Enum +from typing import Optional + +import triton +import triton.language as tl + +_sqrt2pi = math.sqrt(2.0 / math.pi) +_sqrt1_2 = math.sqrt(1.0 / 2) +_gaussian_pdf_normalization = 1.0 / math.sqrt(2 * math.pi) + + +class Activation(str, Enum): + SquaredReLU = "squared_relu" + GeLU = "gelu" + GeLUApprox = "gelu_approx" + LeakyReLU = "leaky_relu" + ReLU = "relu" + + +def get_triton_activation_kernel(activation: Optional[Activation]): + return ( + { + Activation.ReLU: relu, + Activation.LeakyReLU: leaky_relu, + Activation.GeLU: gelu, + Activation.GeLUApprox: gelu_approx, + Activation.SquaredReLU: squared_relu, + }[activation] + if activation + else None + ) + + +def get_triton_activation_bwd_kernel(activation: Optional[Activation]): + return ( + { + Activation.ReLU: relu_grad, + Activation.LeakyReLU: leaky_relu_grad, + Activation.GeLU: gelu_grad, + Activation.GeLUApprox: gelu_approx_grad, + Activation.SquaredReLU: squared_relu_grad, + }[activation] + if activation + else None + ) + + +@triton.jit +def tanh(x): + # Tanh is just a scaled sigmoid + return 2 * tl.sigmoid(2 * x) - 1 + + +@triton.jit +def cosh(x): + exp_x = tl.exp(x) + return (exp_x + 1.0 / exp_x) * 0.5 + + +# a Triton implementation of the most used activations +# See for instance http://arxiv.org/abs/1606.08415 for an overview + +# ReLU +@triton.jit +def relu(x): + """ + ReLU_ activation function + + .. _ReLU: https://pytorch.org/docs/stable/generated/torch.nn.ReLU.html + """ + zero = 0.0 + return tl.where(x >= 0, x, zero.to(x.dtype)) + + +@triton.jit +def relu_grad(x): + # ReLU is different from other activations + # in that it does not require the input to retrospectively compute its gradient + # here the input is the downstream gradient, and we return the upstream gradient directly + zero = 0.0 + one = 1.0 + return tl.where(x >= 0, one.to(x.dtype), zero.to(x.dtype)) + + +@triton.jit +def squared_relu(x): + """ + Squared ReLU activation, as proposed in the Primer_ paper. + + .. _Primer: https://arxiv.org/abs/2109.08668 + """ + x_ = relu(x) + return (x_ * x_).to(x.dtype) + + +@triton.jit +def squared_relu_grad(x): + return tl.where(x >= 0, 2.0 * x, 0.0) + + +# Leaky ReLU +@triton.jit +def leaky_relu(x): + """ + LeakyReLU_ activation + + .. _LeakyReLU: https://pytorch.org/docs/stable/generated/torch.nn.LeakyReLU.html + """ + scale = 0.01 + 0.0 + scale = scale.to(x.dtype) + return tl.where(x >= 0, x, scale * x) + + +@triton.jit +def leaky_relu_grad(x): + min_grad = 0.01 + max_grad = 1 + + min_grad = min_grad.to(x.dtype) + max_grad = max_grad.to(x.dtype) + + return tl.where(x >= 0, max_grad, min_grad) + + +@triton.jit +def gelu(x): + """Gaussian Error Linear Unit (GELU)""" + return x * 0.5 * (1.0 + tl.libdevice.erf(x * _sqrt1_2)) + + +@triton.jit +def gelu_grad(x): + cdf = 0.5 * (1.0 + tl.libdevice.erf(x * _sqrt1_2)) + pdf = tl.exp(-0.5 * x * x) * _gaussian_pdf_normalization + return cdf + x * pdf + + +@triton.jit +def gelu_approx(x): + """ + GeLU_ activation - Gaussian error linear unit, with tanh approximation + + .. 
_GeLU: https://arxiv.org/pdf/1606.08415.pdf + """ + return 0.5 * x * (1.0 + tanh(_sqrt2pi * x * (1.0 + 0.044715 * x * x))) + + +@triton.jit +def gelu_approx_grad(x): + # CREDITS: Fast implementation proposed in + # https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/fused_bias_gelu.py#L30 + tanh_out = tanh(0.79788456 * x * (1 + 0.044715 * x * x)) + return 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * ( + 1 + tanh_out + ) diff --git a/.venv/lib/python3.11/site-packages/xformers/_flash_attn/ops/triton/layer_norm.py b/.venv/lib/python3.11/site-packages/xformers/_flash_attn/ops/triton/layer_norm.py new file mode 100644 index 0000000000000000000000000000000000000000..addffe1f18585b59a8c22f20d3708ffb46d6bf34 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/xformers/_flash_attn/ops/triton/layer_norm.py @@ -0,0 +1,1112 @@ +# Copyright (c) 2024, Tri Dao. +# Implement dropout + residual + layer_norm / rms_norm. + +# Based on the Triton LayerNorm tutorial: https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html +# For the backward pass, we keep weight_grad and bias_grad in registers and accumulate. +# This is faster for dimensions up to 8k, but after that it's much slower due to register spilling. +# The models we train have hidden dim up to 8k anyway (e.g. Llama 70B), so this is fine. + +import math + +import torch +import torch.nn.functional as F +from torch.cuda.amp import custom_fwd, custom_bwd + +import triton +import triton.language as tl + + +def layer_norm_ref( + x, + weight, + bias, + residual=None, + x1=None, + weight1=None, + bias1=None, + eps=1e-6, + dropout_p=0.0, + rowscale=None, + prenorm=False, + dropout_mask=None, + dropout_mask1=None, + upcast=False, +): + dtype = x.dtype + if upcast: + x = x.float() + weight = weight.float() + bias = bias.float() if bias is not None else None + residual = residual.float() if residual is not None else residual + x1 = x1.float() if x1 is not None else None + weight1 = weight1.float() if weight1 is not None else None + bias1 = bias1.float() if bias1 is not None else None + if x1 is not None: + assert rowscale is None, "rowscale is not supported with parallel LayerNorm" + if rowscale is not None: + x = x * rowscale[..., None] + if dropout_p > 0.0: + if dropout_mask is not None: + x = x.masked_fill(~dropout_mask, 0.0) / (1.0 - dropout_p) + else: + x = F.dropout(x, p=dropout_p) + if x1 is not None: + if dropout_mask1 is not None: + x1 = x1.masked_fill(~dropout_mask1, 0.0) / (1.0 - dropout_p) + else: + x1 = F.dropout(x1, p=dropout_p) + if x1 is not None: + x = x + x1 + if residual is not None: + x = (x + residual).to(x.dtype) + out = F.layer_norm(x.to(weight.dtype), x.shape[-1:], weight=weight, bias=bias, eps=eps).to( + dtype + ) + if weight1 is None: + return out if not prenorm else (out, x) + else: + out1 = F.layer_norm( + x.to(weight1.dtype), x.shape[-1:], weight=weight1, bias=bias1, eps=eps + ).to(dtype) + return (out, out1) if not prenorm else (out, out1, x) + + +def rms_norm_ref( + x, + weight, + bias, + residual=None, + x1=None, + weight1=None, + bias1=None, + eps=1e-6, + dropout_p=0.0, + rowscale=None, + prenorm=False, + dropout_mask=None, + dropout_mask1=None, + upcast=False, +): + dtype = x.dtype + if upcast: + x = x.float() + weight = weight.float() + bias = bias.float() if bias is not None else None + residual = residual.float() if residual is not None else residual + x1 = x1.float() if x1 is not None else None + weight1 = weight1.float() if weight1 is not None else None 
+ bias1 = bias1.float() if bias1 is not None else None + if x1 is not None: + assert rowscale is None, "rowscale is not supported with parallel LayerNorm" + if rowscale is not None: + x = x * rowscale[..., None] + if dropout_p > 0.0: + if dropout_mask is not None: + x = x.masked_fill(~dropout_mask, 0.0) / (1.0 - dropout_p) + else: + x = F.dropout(x, p=dropout_p) + if x1 is not None: + if dropout_mask1 is not None: + x1 = x1.masked_fill(~dropout_mask1, 0.0) / (1.0 - dropout_p) + else: + x1 = F.dropout(x1, p=dropout_p) + if x1 is not None: + x = x + x1 + if residual is not None: + x = (x + residual).to(x.dtype) + rstd = 1 / torch.sqrt((x.square()).mean(dim=-1, keepdim=True) + eps) + out = ((x * rstd * weight) + bias if bias is not None else (x * rstd * weight)).to(dtype) + if weight1 is None: + return out if not prenorm else (out, x) + else: + out1 = ((x * rstd * weight1) + bias1 if bias1 is not None else (x * rstd * weight1)).to( + dtype + ) + return (out, out1) if not prenorm else (out, out1, x) + + +@triton.autotune( + configs=[ + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4), + triton.Config({}, num_warps=8), + triton.Config({}, num_warps=16), + triton.Config({}, num_warps=32), + ], + key=["N", "HAS_RESIDUAL", "STORE_RESIDUAL_OUT", "IS_RMS_NORM", "HAS_BIAS"], +) +# @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None}) +# @triton.heuristics({"HAS_RESIDUAL": lambda args: args["RESIDUAL"] is not None}) +@triton.heuristics({"HAS_X1": lambda args: args["X1"] is not None}) +@triton.heuristics({"HAS_W1": lambda args: args["W1"] is not None}) +@triton.heuristics({"HAS_B1": lambda args: args["B1"] is not None}) +@triton.jit +def _layer_norm_fwd_1pass_kernel( + X, # pointer to the input + Y, # pointer to the output + W, # pointer to the weights + B, # pointer to the biases + RESIDUAL, # pointer to the residual + X1, + W1, + B1, + Y1, + RESIDUAL_OUT, # pointer to the residual + ROWSCALE, + SEEDS, # Dropout seeds for each row + DROPOUT_MASK, + Mean, # pointer to the mean + Rstd, # pointer to the 1/std + stride_x_row, # how much to increase the pointer when moving by 1 row + stride_y_row, + stride_res_row, + stride_res_out_row, + stride_x1_row, + stride_y1_row, + M, # number of rows in X + N, # number of columns in X + eps, # epsilon to avoid division by zero + dropout_p, # Dropout probability + IS_RMS_NORM: tl.constexpr, + BLOCK_N: tl.constexpr, + HAS_RESIDUAL: tl.constexpr, + STORE_RESIDUAL_OUT: tl.constexpr, + HAS_BIAS: tl.constexpr, + HAS_DROPOUT: tl.constexpr, + STORE_DROPOUT_MASK: tl.constexpr, + HAS_ROWSCALE: tl.constexpr, + HAS_X1: tl.constexpr, + HAS_W1: tl.constexpr, + HAS_B1: tl.constexpr, +): + # Map the program id to the row of X and Y it should compute. 
+ row = tl.program_id(0) + X += row * stride_x_row + Y += row * stride_y_row + if HAS_RESIDUAL: + RESIDUAL += row * stride_res_row + if STORE_RESIDUAL_OUT: + RESIDUAL_OUT += row * stride_res_out_row + if HAS_X1: + X1 += row * stride_x1_row + if HAS_W1: + Y1 += row * stride_y1_row + # Compute mean and variance + cols = tl.arange(0, BLOCK_N) + x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32) + if HAS_ROWSCALE: + rowscale = tl.load(ROWSCALE + row).to(tl.float32) + x *= rowscale + if HAS_DROPOUT: + # Compute dropout mask + # 7 rounds is good enough, and reduces register pressure + keep_mask = tl.rand(tl.load(SEEDS + row).to(tl.uint32), cols, n_rounds=7) > dropout_p + x = tl.where(keep_mask, x / (1.0 - dropout_p), 0.0) + if STORE_DROPOUT_MASK: + tl.store(DROPOUT_MASK + row * N + cols, keep_mask, mask=cols < N) + if HAS_X1: + x1 = tl.load(X1 + cols, mask=cols < N, other=0.0).to(tl.float32) + if HAS_ROWSCALE: + rowscale = tl.load(ROWSCALE + M + row).to(tl.float32) + x1 *= rowscale + if HAS_DROPOUT: + # Compute dropout mask + # 7 rounds is good enough, and reduces register pressure + keep_mask = ( + tl.rand(tl.load(SEEDS + M + row).to(tl.uint32), cols, n_rounds=7) > dropout_p + ) + x1 = tl.where(keep_mask, x1 / (1.0 - dropout_p), 0.0) + if STORE_DROPOUT_MASK: + tl.store(DROPOUT_MASK + (M + row) * N + cols, keep_mask, mask=cols < N) + x += x1 + if HAS_RESIDUAL: + residual = tl.load(RESIDUAL + cols, mask=cols < N, other=0.0).to(tl.float32) + x += residual + if STORE_RESIDUAL_OUT: + tl.store(RESIDUAL_OUT + cols, x, mask=cols < N) + if not IS_RMS_NORM: + mean = tl.sum(x, axis=0) / N + tl.store(Mean + row, mean) + xbar = tl.where(cols < N, x - mean, 0.0) + var = tl.sum(xbar * xbar, axis=0) / N + else: + xbar = tl.where(cols < N, x, 0.0) + var = tl.sum(xbar * xbar, axis=0) / N + rstd = 1 / tl.sqrt(var + eps) + tl.store(Rstd + row, rstd) + # Normalize and apply linear transformation + mask = cols < N + w = tl.load(W + cols, mask=mask).to(tl.float32) + if HAS_BIAS: + b = tl.load(B + cols, mask=mask).to(tl.float32) + x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd + y = x_hat * w + b if HAS_BIAS else x_hat * w + # Write output + tl.store(Y + cols, y, mask=mask) + if HAS_W1: + w1 = tl.load(W1 + cols, mask=mask).to(tl.float32) + if HAS_B1: + b1 = tl.load(B1 + cols, mask=mask).to(tl.float32) + y1 = x_hat * w1 + b1 if HAS_B1 else x_hat * w1 + tl.store(Y1 + cols, y1, mask=mask) + + +def _layer_norm_fwd( + x, + weight, + bias, + eps, + residual=None, + x1=None, + weight1=None, + bias1=None, + dropout_p=0.0, + rowscale=None, + out_dtype=None, + residual_dtype=None, + is_rms_norm=False, + return_dropout_mask=False, + out=None, + residual_out=None +): + if residual is not None: + residual_dtype = residual.dtype + M, N = x.shape + assert x.stride(-1) == 1 + if residual is not None: + assert residual.stride(-1) == 1 + assert residual.shape == (M, N) + assert weight.shape == (N,) + assert weight.stride(-1) == 1 + if bias is not None: + assert bias.stride(-1) == 1 + assert bias.shape == (N,) + if x1 is not None: + assert x1.shape == x.shape + assert rowscale is None + assert x1.stride(-1) == 1 + if weight1 is not None: + assert weight1.shape == (N,) + assert weight1.stride(-1) == 1 + if bias1 is not None: + assert bias1.shape == (N,) + assert bias1.stride(-1) == 1 + if rowscale is not None: + assert rowscale.is_contiguous() + assert rowscale.shape == (M,) + # allocate output + if out is None: + out = torch.empty_like(x, dtype=x.dtype if out_dtype is None else out_dtype) + else: + assert 
out.shape == x.shape + assert out.stride(-1) == 1 + if weight1 is not None: + y1 = torch.empty_like(out) + assert y1.stride(-1) == 1 + else: + y1 = None + if ( + residual is not None + or (residual_dtype is not None and residual_dtype != x.dtype) + or dropout_p > 0.0 + or rowscale is not None + or x1 is not None + ): + if residual_out is None: + residual_out = torch.empty( + M, N, device=x.device, dtype=residual_dtype if residual_dtype is not None else x.dtype + ) + else: + assert residual_out.shape == x.shape + assert residual_out.stride(-1) == 1 + else: + residual_out = None + mean = torch.empty((M,), dtype=torch.float32, device=x.device) if not is_rms_norm else None + rstd = torch.empty((M,), dtype=torch.float32, device=x.device) + if dropout_p > 0.0: + seeds = torch.randint( + 2**32, (M if x1 is None else 2 * M,), device=x.device, dtype=torch.int64 + ) + else: + seeds = None + if return_dropout_mask and dropout_p > 0.0: + dropout_mask = torch.empty(M if x1 is None else 2 * M, N, device=x.device, dtype=torch.bool) + else: + dropout_mask = None + # Less than 64KB per feature: enqueue fused kernel + MAX_FUSED_SIZE = 65536 // x.element_size() + BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N)) + if N > BLOCK_N: + raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.") + with torch.cuda.device(x.device.index): + _layer_norm_fwd_1pass_kernel[(M,)]( + x, + out, + weight, + bias, + residual, + x1, + weight1, + bias1, + y1, + residual_out, + rowscale, + seeds, + dropout_mask, + mean, + rstd, + x.stride(0), + out.stride(0), + residual.stride(0) if residual is not None else 0, + residual_out.stride(0) if residual_out is not None else 0, + x1.stride(0) if x1 is not None else 0, + y1.stride(0) if y1 is not None else 0, + M, + N, + eps, + dropout_p, + is_rms_norm, + BLOCK_N, + residual is not None, + residual_out is not None, + bias is not None, + dropout_p > 0.0, + dropout_mask is not None, + rowscale is not None, + ) + # residual_out is None if residual is None and residual_dtype == input_dtype and dropout_p == 0.0 + if dropout_mask is not None and x1 is not None: + dropout_mask, dropout_mask1 = dropout_mask.tensor_split(2, dim=0) + else: + dropout_mask1 = None + return ( + out, + y1, + mean, + rstd, + residual_out if residual_out is not None else x, + seeds, + dropout_mask, + dropout_mask1, + ) + + +@triton.autotune( + configs=[ + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4), + triton.Config({}, num_warps=8), + triton.Config({}, num_warps=16), + triton.Config({}, num_warps=32), + ], + key=["N", "HAS_DRESIDUAL", "STORE_DRESIDUAL", "IS_RMS_NORM", "HAS_BIAS", "HAS_DROPOUT"], +) +# @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None}) +# @triton.heuristics({"HAS_DRESIDUAL": lambda args: args["DRESIDUAL"] is not None}) +# @triton.heuristics({"STORE_DRESIDUAL": lambda args: args["DRESIDUAL_IN"] is not None}) +@triton.heuristics({"HAS_ROWSCALE": lambda args: args["ROWSCALE"] is not None}) +@triton.heuristics({"HAS_DY1": lambda args: args["DY1"] is not None}) +@triton.heuristics({"HAS_DX1": lambda args: args["DX1"] is not None}) +@triton.heuristics({"HAS_B1": lambda args: args["DB1"] is not None}) +@triton.heuristics({"RECOMPUTE_OUTPUT": lambda args: args["Y"] is not None}) +@triton.jit +def _layer_norm_bwd_kernel( + X, # pointer to the input + W, # pointer to the weights + B, # pointer to the biases + Y, # pointer to the output to be recomputed + DY, # pointer to the output gradient + DX, # pointer to the 
input gradient + DW, # pointer to the partial sum of weights gradient + DB, # pointer to the partial sum of biases gradient + DRESIDUAL, + W1, + DY1, + DX1, + DW1, + DB1, + DRESIDUAL_IN, + ROWSCALE, + SEEDS, + Mean, # pointer to the mean + Rstd, # pointer to the 1/std + stride_x_row, # how much to increase the pointer when moving by 1 row + stride_y_row, + stride_dy_row, + stride_dx_row, + stride_dres_row, + stride_dy1_row, + stride_dx1_row, + stride_dres_in_row, + M, # number of rows in X + N, # number of columns in X + eps, # epsilon to avoid division by zero + dropout_p, + rows_per_program, + IS_RMS_NORM: tl.constexpr, + BLOCK_N: tl.constexpr, + HAS_DRESIDUAL: tl.constexpr, + STORE_DRESIDUAL: tl.constexpr, + HAS_BIAS: tl.constexpr, + HAS_DROPOUT: tl.constexpr, + HAS_ROWSCALE: tl.constexpr, + HAS_DY1: tl.constexpr, + HAS_DX1: tl.constexpr, + HAS_B1: tl.constexpr, + RECOMPUTE_OUTPUT: tl.constexpr, +): + # Map the program id to the elements of X, DX, and DY it should compute. + row_block_id = tl.program_id(0) + row_start = row_block_id * rows_per_program + # Do not early exit if row_start >= M, because we need to write DW and DB + cols = tl.arange(0, BLOCK_N) + mask = cols < N + X += row_start * stride_x_row + if HAS_DRESIDUAL: + DRESIDUAL += row_start * stride_dres_row + if STORE_DRESIDUAL: + DRESIDUAL_IN += row_start * stride_dres_in_row + DY += row_start * stride_dy_row + DX += row_start * stride_dx_row + if HAS_DY1: + DY1 += row_start * stride_dy1_row + if HAS_DX1: + DX1 += row_start * stride_dx1_row + if RECOMPUTE_OUTPUT: + Y += row_start * stride_y_row + w = tl.load(W + cols, mask=mask).to(tl.float32) + if RECOMPUTE_OUTPUT and HAS_BIAS: + b = tl.load(B + cols, mask=mask, other=0.0).to(tl.float32) + if HAS_DY1: + w1 = tl.load(W1 + cols, mask=mask).to(tl.float32) + dw = tl.zeros((BLOCK_N,), dtype=tl.float32) + if HAS_BIAS: + db = tl.zeros((BLOCK_N,), dtype=tl.float32) + if HAS_DY1: + dw1 = tl.zeros((BLOCK_N,), dtype=tl.float32) + if HAS_B1: + db1 = tl.zeros((BLOCK_N,), dtype=tl.float32) + row_end = min((row_block_id + 1) * rows_per_program, M) + for row in range(row_start, row_end): + # Load data to SRAM + x = tl.load(X + cols, mask=mask, other=0).to(tl.float32) + dy = tl.load(DY + cols, mask=mask, other=0).to(tl.float32) + if HAS_DY1: + dy1 = tl.load(DY1 + cols, mask=mask, other=0).to(tl.float32) + if not IS_RMS_NORM: + mean = tl.load(Mean + row) + rstd = tl.load(Rstd + row) + # Compute dx + xhat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd + xhat = tl.where(mask, xhat, 0.0) + if RECOMPUTE_OUTPUT: + y = xhat * w + b if HAS_BIAS else xhat * w + tl.store(Y + cols, y, mask=mask) + wdy = w * dy + dw += dy * xhat + if HAS_BIAS: + db += dy + if HAS_DY1: + wdy += w1 * dy1 + dw1 += dy1 * xhat + if HAS_B1: + db1 += dy1 + if not IS_RMS_NORM: + c1 = tl.sum(xhat * wdy, axis=0) / N + c2 = tl.sum(wdy, axis=0) / N + dx = (wdy - (xhat * c1 + c2)) * rstd + else: + c1 = tl.sum(xhat * wdy, axis=0) / N + dx = (wdy - xhat * c1) * rstd + if HAS_DRESIDUAL: + dres = tl.load(DRESIDUAL + cols, mask=mask, other=0).to(tl.float32) + dx += dres + # Write dx + if STORE_DRESIDUAL: + tl.store(DRESIDUAL_IN + cols, dx, mask=mask) + if HAS_DX1: + if HAS_DROPOUT: + keep_mask = ( + tl.rand(tl.load(SEEDS + M + row).to(tl.uint32), cols, n_rounds=7) > dropout_p + ) + dx1 = tl.where(keep_mask, dx / (1.0 - dropout_p), 0.0) + else: + dx1 = dx + tl.store(DX1 + cols, dx1, mask=mask) + if HAS_DROPOUT: + keep_mask = tl.rand(tl.load(SEEDS + row).to(tl.uint32), cols, n_rounds=7) > dropout_p + dx = tl.where(keep_mask, dx / 
(1.0 - dropout_p), 0.0) + if HAS_ROWSCALE: + rowscale = tl.load(ROWSCALE + row).to(tl.float32) + dx *= rowscale + tl.store(DX + cols, dx, mask=mask) + + X += stride_x_row + if HAS_DRESIDUAL: + DRESIDUAL += stride_dres_row + if STORE_DRESIDUAL: + DRESIDUAL_IN += stride_dres_in_row + if RECOMPUTE_OUTPUT: + Y += stride_y_row + DY += stride_dy_row + DX += stride_dx_row + if HAS_DY1: + DY1 += stride_dy1_row + if HAS_DX1: + DX1 += stride_dx1_row + tl.store(DW + row_block_id * N + cols, dw, mask=mask) + if HAS_BIAS: + tl.store(DB + row_block_id * N + cols, db, mask=mask) + if HAS_DY1: + tl.store(DW1 + row_block_id * N + cols, dw1, mask=mask) + if HAS_B1: + tl.store(DB1 + row_block_id * N + cols, db1, mask=mask) + + +def _layer_norm_bwd( + dy, + x, + weight, + bias, + eps, + mean, + rstd, + dresidual=None, + dy1=None, + weight1=None, + bias1=None, + seeds=None, + dropout_p=0.0, + rowscale=None, + has_residual=False, + has_x1=False, + is_rms_norm=False, + x_dtype=None, + recompute_output=False, +): + M, N = x.shape + assert x.stride(-1) == 1 + assert dy.stride(-1) == 1 + assert dy.shape == (M, N) + if dresidual is not None: + assert dresidual.stride(-1) == 1 + assert dresidual.shape == (M, N) + assert weight.shape == (N,) + assert weight.stride(-1) == 1 + if bias is not None: + assert bias.stride(-1) == 1 + assert bias.shape == (N,) + if dy1 is not None: + assert weight1 is not None + assert dy1.shape == dy.shape + assert dy1.stride(-1) == 1 + if weight1 is not None: + assert weight1.shape == (N,) + assert weight1.stride(-1) == 1 + if bias1 is not None: + assert bias1.shape == (N,) + assert bias1.stride(-1) == 1 + if seeds is not None: + assert seeds.is_contiguous() + assert seeds.shape == (M if not has_x1 else M * 2,) + if rowscale is not None: + assert rowscale.is_contiguous() + assert rowscale.shape == (M,) + # allocate output + dx = ( + torch.empty_like(x) + if x_dtype is None + else torch.empty(M, N, dtype=x_dtype, device=x.device) + ) + dresidual_in = ( + torch.empty_like(x) + if has_residual + and (dx.dtype != x.dtype or dropout_p > 0.0 or rowscale is not None or has_x1) + else None + ) + dx1 = torch.empty_like(dx) if (has_x1 and dropout_p > 0.0) else None + y = torch.empty(M, N, dtype=dy.dtype, device=dy.device) if recompute_output else None + if recompute_output: + assert weight1 is None, "recompute_output is not supported with parallel LayerNorm" + + # Less than 64KB per feature: enqueue fused kernel + MAX_FUSED_SIZE = 65536 // x.element_size() + BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N)) + if N > BLOCK_N: + raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.") + sm_count = torch.cuda.get_device_properties(x.device).multi_processor_count + _dw = torch.empty((sm_count, N), dtype=torch.float32, device=weight.device) + _db = ( + torch.empty((sm_count, N), dtype=torch.float32, device=bias.device) + if bias is not None + else None + ) + _dw1 = torch.empty_like(_dw) if weight1 is not None else None + _db1 = torch.empty_like(_db) if bias1 is not None else None + rows_per_program = math.ceil(M / sm_count) + grid = (sm_count,) + with torch.cuda.device(x.device.index): + _layer_norm_bwd_kernel[grid]( + x, + weight, + bias, + y, + dy, + dx, + _dw, + _db, + dresidual, + weight1, + dy1, + dx1, + _dw1, + _db1, + dresidual_in, + rowscale, + seeds, + mean, + rstd, + x.stride(0), + 0 if not recompute_output else y.stride(0), + dy.stride(0), + dx.stride(0), + dresidual.stride(0) if dresidual is not None else 0, + dy1.stride(0) if dy1 is not None else 0, + dx1.stride(0) if 
dx1 is not None else 0, + dresidual_in.stride(0) if dresidual_in is not None else 0, + M, + N, + eps, + dropout_p, + rows_per_program, + is_rms_norm, + BLOCK_N, + dresidual is not None, + dresidual_in is not None, + bias is not None, + dropout_p > 0.0, + ) + dw = _dw.sum(0).to(weight.dtype) + db = _db.sum(0).to(bias.dtype) if bias is not None else None + dw1 = _dw1.sum(0).to(weight1.dtype) if weight1 is not None else None + db1 = _db1.sum(0).to(bias1.dtype) if bias1 is not None else None + # Don't need to compute dresidual_in separately in this case + if has_residual and dx.dtype == x.dtype and dropout_p == 0.0 and rowscale is None: + dresidual_in = dx + if has_x1 and dropout_p == 0.0: + dx1 = dx + return ( + (dx, dw, db, dresidual_in, dx1, dw1, db1) + if not recompute_output + else (dx, dw, db, dresidual_in, dx1, dw1, db1, y) + ) + + +class LayerNormFn(torch.autograd.Function): + @staticmethod + def forward( + ctx, + x, + weight, + bias, + residual=None, + x1=None, + weight1=None, + bias1=None, + eps=1e-6, + dropout_p=0.0, + rowscale=None, + prenorm=False, + residual_in_fp32=False, + is_rms_norm=False, + return_dropout_mask=False, + out=None, + residual_out=None + ): + x_shape_og = x.shape + # reshape input data into 2D tensor + x = x.reshape(-1, x.shape[-1]) + if x.stride(-1) != 1: + x = x.contiguous() + if residual is not None: + assert residual.shape == x_shape_og + residual = residual.reshape(-1, residual.shape[-1]) + if residual.stride(-1) != 1: + residual = residual.contiguous() + if x1 is not None: + assert x1.shape == x_shape_og + assert rowscale is None, "rowscale is not supported with parallel LayerNorm" + x1 = x1.reshape(-1, x1.shape[-1]) + if x1.stride(-1) != 1: + x1 = x1.contiguous() + weight = weight.contiguous() + if bias is not None: + bias = bias.contiguous() + if weight1 is not None: + weight1 = weight1.contiguous() + if bias1 is not None: + bias1 = bias1.contiguous() + if rowscale is not None: + rowscale = rowscale.reshape(-1).contiguous() + residual_dtype = ( + residual.dtype + if residual is not None + else (torch.float32 if residual_in_fp32 else None) + ) + if out is not None: + out = out.reshape(-1, out.shape[-1]) + if residual_out is not None: + residual_out = residual_out.reshape(-1, residual_out.shape[-1]) + y, y1, mean, rstd, residual_out, seeds, dropout_mask, dropout_mask1 = _layer_norm_fwd( + x, + weight, + bias, + eps, + residual, + x1, + weight1, + bias1, + dropout_p=dropout_p, + rowscale=rowscale, + residual_dtype=residual_dtype, + is_rms_norm=is_rms_norm, + return_dropout_mask=return_dropout_mask, + out=out, + residual_out=residual_out + ) + ctx.save_for_backward( + residual_out, weight, bias, weight1, bias1, rowscale, seeds, mean, rstd + ) + ctx.x_shape_og = x_shape_og + ctx.eps = eps + ctx.dropout_p = dropout_p + ctx.is_rms_norm = is_rms_norm + ctx.has_residual = residual is not None + ctx.has_x1 = x1 is not None + ctx.prenorm = prenorm + ctx.x_dtype = x.dtype + y = y.reshape(x_shape_og) + y1 = y1.reshape(x_shape_og) if y1 is not None else None + residual_out = residual_out.reshape(x_shape_og) if residual_out is not None else None + dropout_mask = dropout_mask.reshape(x_shape_og) if dropout_mask is not None else None + dropout_mask1 = dropout_mask1.reshape(x_shape_og) if dropout_mask1 is not None else None + if not return_dropout_mask: + if weight1 is None: + return y if not prenorm else (y, residual_out) + else: + return (y, y1) if not prenorm else (y, y1, residual_out) + else: + if weight1 is None: + return ( + (y, dropout_mask, dropout_mask1) + if not 
prenorm + else (y, residual_out, dropout_mask, dropout_mask1) + ) + else: + return ( + (y, y1, dropout_mask, dropout_mask1) + if not prenorm + else (y, y1, residual_out, dropout_mask, dropout_mask1) + ) + + @staticmethod + def backward(ctx, dy, *args): + x, weight, bias, weight1, bias1, rowscale, seeds, mean, rstd = ctx.saved_tensors + dy = dy.reshape(-1, dy.shape[-1]) + if dy.stride(-1) != 1: + dy = dy.contiguous() + assert dy.shape == x.shape + if weight1 is not None: + dy1, args = args[0], args[1:] + dy1 = dy1.reshape(-1, dy1.shape[-1]) + if dy1.stride(-1) != 1: + dy1 = dy1.contiguous() + assert dy1.shape == x.shape + else: + dy1 = None + if ctx.prenorm: + dresidual = args[0] + dresidual = dresidual.reshape(-1, dresidual.shape[-1]) + if dresidual.stride(-1) != 1: + dresidual = dresidual.contiguous() + assert dresidual.shape == x.shape + else: + dresidual = None + dx, dw, db, dresidual_in, dx1, dw1, db1 = _layer_norm_bwd( + dy, + x, + weight, + bias, + ctx.eps, + mean, + rstd, + dresidual, + dy1, + weight1, + bias1, + seeds, + ctx.dropout_p, + rowscale, + ctx.has_residual, + ctx.has_x1, + ctx.is_rms_norm, + x_dtype=ctx.x_dtype, + ) + return ( + dx.reshape(ctx.x_shape_og), + dw, + db, + dresidual_in.reshape(ctx.x_shape_og) if ctx.has_residual else None, + dx1.reshape(ctx.x_shape_og) if dx1 is not None else None, + dw1, + db1, + None, + None, + None, + None, + None, + None, + None, + None, + None, + ) + + +def layer_norm_fn( + x, + weight, + bias, + residual=None, + x1=None, + weight1=None, + bias1=None, + eps=1e-6, + dropout_p=0.0, + rowscale=None, + prenorm=False, + residual_in_fp32=False, + is_rms_norm=False, + return_dropout_mask=False, + out=None, + residual_out=None +): + return LayerNormFn.apply( + x, + weight, + bias, + residual, + x1, + weight1, + bias1, + eps, + dropout_p, + rowscale, + prenorm, + residual_in_fp32, + is_rms_norm, + return_dropout_mask, + out, + residual_out + ) + + +def rms_norm_fn( + x, + weight, + bias, + residual=None, + x1=None, + weight1=None, + bias1=None, + eps=1e-6, + dropout_p=0.0, + rowscale=None, + prenorm=False, + residual_in_fp32=False, + return_dropout_mask=False, + out=None, + residual_out=None +): + return LayerNormFn.apply( + x, + weight, + bias, + residual, + x1, + weight1, + bias1, + eps, + dropout_p, + rowscale, + prenorm, + residual_in_fp32, + True, + return_dropout_mask, + out, + residual_out + ) + + +class RMSNorm(torch.nn.Module): + + def __init__(self, hidden_size, eps=1e-5, dropout_p=0.0, device=None, dtype=None): + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + self.eps = eps + if dropout_p > 0.0: + self.drop = torch.nn.Dropout(dropout_p) + else: + self.drop = None + self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs)) + self.register_parameter("bias", None) + self.reset_parameters() + + def reset_parameters(self): + torch.nn.init.ones_(self.weight) + + def forward(self, x, residual=None, prenorm=False, residual_in_fp32=False): + return rms_norm_fn( + x, + self.weight, + self.bias, + residual=residual, + eps=self.eps, + dropout_p=self.drop.p if self.drop is not None and self.training else 0.0, + prenorm=prenorm, + residual_in_fp32=residual_in_fp32, + ) + + +class LayerNormLinearFn(torch.autograd.Function): + @staticmethod + @custom_fwd + def forward( + ctx, + x, + norm_weight, + norm_bias, + linear_weight, + linear_bias, + residual=None, + eps=1e-6, + prenorm=False, + residual_in_fp32=False, + is_rms_norm=False, + ): + x_shape_og = x.shape + # reshape input data into 2D tensor + x = 
x.reshape(-1, x.shape[-1]) + if x.stride(-1) != 1: + x = x.contiguous() + if residual is not None: + assert residual.shape == x_shape_og + residual = residual.reshape(-1, residual.shape[-1]) + if residual.stride(-1) != 1: + residual = residual.contiguous() + norm_weight = norm_weight.contiguous() + if norm_bias is not None: + norm_bias = norm_bias.contiguous() + residual_dtype = ( + residual.dtype + if residual is not None + else (torch.float32 if residual_in_fp32 else None) + ) + y, _, mean, rstd, residual_out, *rest = _layer_norm_fwd( + x, + norm_weight, + norm_bias, + eps, + residual, + out_dtype=None if not torch.is_autocast_enabled() else torch.get_autocast_gpu_dtype(), + residual_dtype=residual_dtype, + is_rms_norm=is_rms_norm, + ) + y = y.reshape(x_shape_og) + dtype = torch.get_autocast_gpu_dtype() if torch.is_autocast_enabled() else y.dtype + linear_weight = linear_weight.to(dtype) + linear_bias = linear_bias.to(dtype) if linear_bias is not None else None + out = F.linear(y.to(linear_weight.dtype), linear_weight, linear_bias) + # We don't store y, will be recomputed in the backward pass to save memory + ctx.save_for_backward(residual_out, norm_weight, norm_bias, linear_weight, mean, rstd) + ctx.x_shape_og = x_shape_og + ctx.eps = eps + ctx.is_rms_norm = is_rms_norm + ctx.has_residual = residual is not None + ctx.prenorm = prenorm + ctx.x_dtype = x.dtype + ctx.linear_bias_is_none = linear_bias is None + return out if not prenorm else (out, residual_out.reshape(x_shape_og)) + + @staticmethod + @custom_bwd + def backward(ctx, dout, *args): + x, norm_weight, norm_bias, linear_weight, mean, rstd = ctx.saved_tensors + dout = dout.reshape(-1, dout.shape[-1]) + dy = F.linear(dout, linear_weight.t()) + dlinear_bias = None if ctx.linear_bias_is_none else dout.sum(0) + if dy.stride(-1) != 1: + dy = dy.contiguous() + assert dy.shape == x.shape + if ctx.prenorm: + dresidual = args[0] + dresidual = dresidual.reshape(-1, dresidual.shape[-1]) + if dresidual.stride(-1) != 1: + dresidual = dresidual.contiguous() + assert dresidual.shape == x.shape + else: + dresidual = None + dx, dnorm_weight, dnorm_bias, dresidual_in, _, _, _, y = _layer_norm_bwd( + dy, + x, + norm_weight, + norm_bias, + ctx.eps, + mean, + rstd, + dresidual=dresidual, + has_residual=ctx.has_residual, + is_rms_norm=ctx.is_rms_norm, + x_dtype=ctx.x_dtype, + recompute_output=True, + ) + dlinear_weight = torch.einsum("bo,bi->oi", dout, y) + return ( + dx.reshape(ctx.x_shape_og), + dnorm_weight, + dnorm_bias, + dlinear_weight, + dlinear_bias, + dresidual_in.reshape(ctx.x_shape_og) if ctx.has_residual else None, + None, + None, + None, + None, + ) + + +def layer_norm_linear_fn( + x, + norm_weight, + norm_bias, + linear_weight, + linear_bias, + residual=None, + eps=1e-6, + prenorm=False, + residual_in_fp32=False, + is_rms_norm=False, +): + return LayerNormLinearFn.apply( + x, + norm_weight, + norm_bias, + linear_weight, + linear_bias, + residual, + eps, + prenorm, + residual_in_fp32, + is_rms_norm, + ) diff --git a/.venv/lib/python3.11/site-packages/xformers/_flash_attn/ops/triton/linear.py b/.venv/lib/python3.11/site-packages/xformers/_flash_attn/ops/triton/linear.py new file mode 100644 index 0000000000000000000000000000000000000000..a8966dbc345ab0e593df0124451ee7be3dae131a --- /dev/null +++ b/.venv/lib/python3.11/site-packages/xformers/_flash_attn/ops/triton/linear.py @@ -0,0 +1,594 @@ +# Adapted from https://github.com/ELS-RD/kernl/blob/main/src/kernl/implementations/linear_layer.py +# and 
https://github.com/openai/triton/blob/master/python/triton/ops/matmul.py +from typing import Optional + +import torch +import triton +import triton.language as tl +from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time + +from flash_attn.ops.triton.k_activations import ( + gelu, + gelu_approx, + gelu_approx_grad, + gelu_grad, + squared_relu, + squared_relu_grad, +) + +# CREDITS: Initially inspired by the Triton tutorial on matrix multiplications + + +def init_to_zero(name): + return lambda nargs: nargs[name].zero_() + + +def get_configs_io_bound(): + configs = [] + for num_stages in [2, 3, 4, 5, 6]: + for block_m in [16, 32]: + for block_k in [32, 64]: + for block_n in [32, 64, 128, 256]: + num_warps = 2 if block_n <= 64 else 4 + configs.append( + triton.Config( + { + "BLOCK_M": block_m, + "BLOCK_N": block_n, + "BLOCK_K": block_k, + "SPLIT_K": 1, + }, + num_stages=num_stages, + num_warps=num_warps, + ) + ) + # split_k not used + # for split_k in [2, 4, 8, 16]: + # configs.append(triton.Config( + # {'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': split_k}, + # num_stages=num_stages, num_warps=num_warps, pre_hook=init_to_zero('C'))) + return configs + + +@triton.autotune( + configs=[ + triton.Config( + {"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=3, num_warps=8 + ), + triton.Config( + {"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=3, num_warps=8 + ), + triton.Config( + {"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4 + ), + triton.Config( + {"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4 + ), + triton.Config( + {"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4 + ), + triton.Config( + {"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4 + ), + triton.Config( + {"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4 + ), + triton.Config( + {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4 + ), + triton.Config( + {"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=5, num_warps=2 + ), + # good for int8 + triton.Config( + {"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 128, "SPLIT_K": 1}, + num_stages=3, + num_warps=8, + ), + triton.Config( + {"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 128, "SPLIT_K": 1}, + num_stages=3, + num_warps=8, + ), + triton.Config( + {"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=4, num_warps=4 + ), + triton.Config( + {"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=4, num_warps=4 + ), + triton.Config( + {"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 128, "SPLIT_K": 1}, + num_stages=4, + num_warps=4, + ), + triton.Config( + {"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=4, num_warps=4 + ), + triton.Config( + {"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=4, num_warps=4 + ), + triton.Config( + {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=4, num_warps=4 + ), + triton.Config( + {"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=5, num_warps=2 + ), + ] + + get_configs_io_bound(), + key=["CACHE_KEY_M", "CACHE_KEY_N", "CACHE_KEY_K"], + prune_configs_by={ + "early_config_prune": early_config_prune, + "perf_model": estimate_matmul_time, + "top_k": 10, + }, +) +@triton.heuristics( + { + 
"EVEN_K": lambda args: args["K"] % (args["BLOCK_K"] * args["SPLIT_K"]) == 0, + } +) +@triton.jit +def kernel_fwd( + C, # Pointers to matrices + ACT_INPUT, + A, + B, + bias, + # Matrix dimensions + M, + N, + K, + CACHE_KEY_M, + CACHE_KEY_N, + CACHE_KEY_K, + # The stride variables represent how much to increase the ptr by when moving by 1 + # element in a particular dimension. E.g. stride_am is how much to increase a_ptr + # by to get the element one row down (A has M rows) + stride_cm, + # stride_cn, # Assume that stride_cn == 1 + stride_am, + stride_ak, + stride_bn, + stride_bk, + # Meta-parameters + BLOCK_M: tl.constexpr, + GROUP_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, + # split k not used, not performant with activation, kept because early_config_prune is expecting it + SPLIT_K: tl.constexpr, + EVEN_K: tl.constexpr, + A_ROWMAJOR: tl.constexpr, + B_COLMAJOR: tl.constexpr, + BIAS: tl.constexpr, + SAVE_ACT_INPUT: tl.constexpr, + ACTIVATION: tl.constexpr, +): + + """ + Kernel for computing Out = activation(A x W + C) + - Input has shape (M, K) + - Weight has shape (K, N) + - Bias has shape (N,) + - Output has shape (M, N) + - ActInputs (optional) has shape (M, N) + 'ActInputs' optionally saves the A x W + C intermediate for backward computations + This kernel will consolidate over K + """ + + pid = tl.program_id(axis=0) + + grid_m = (M + BLOCK_M - 1) // BLOCK_M + grid_n = (N + BLOCK_N - 1) // BLOCK_N + # re-order program ID for better L2 performance + width = GROUP_M * grid_n + group_id = pid // width + group_size = min(grid_m - group_id * GROUP_M, GROUP_M) + pid_m = group_id * GROUP_M + (pid % group_size) + pid_n = (pid % width) // (group_size) + + # now compute the block that each program will go through + # rm (resp. rn) denotes a range of indices + # for rows (resp. 
col) of C + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + # trick to avoid masking on M and N axis + ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M) + rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N) + rk = tl.arange(0, BLOCK_K) + + if A_ROWMAJOR: + A = A + (ram[:, None] * stride_am + rk[None, :]) + else: + A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak) + if B_COLMAJOR: + B = B + (rk[:, None] + rbn[None, :] * stride_bn) + else: + B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn) + + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) + + for k in range(K, 0, -BLOCK_K): + if EVEN_K: + a = tl.load(A) + b = tl.load(B) + else: + a = tl.load(A, mask=rk[None, :] < k, other=0.0) + b = tl.load(B, mask=rk[:, None] < k, other=0.0) + acc += tl.dot(a, b) + + if A_ROWMAJOR: + A += BLOCK_K + else: + A += BLOCK_K * stride_ak + if B_COLMAJOR: + B += BLOCK_K + else: + B += BLOCK_K * stride_bk + + # Putting bias after the matmul (instead of before) is faster, though it's unclear why + if BIAS: + bias = tl.load(bias + rn, mask=rn < N, other=0.0).to(tl.float32) + acc += bias[None, :] + + # optional: save the activation inputs + if SAVE_ACT_INPUT: + # act_in_ptrs = ACT_INPUT + ram[:, None] * stride_cm + rbn[None, :] * stride_cn + act_in_ptrs = ACT_INPUT + ram[:, None] * stride_cm + rbn[None, :] + tl.store(act_in_ptrs, acc) + + # optional: fused activation (while the data is in shared memory) + if ACTIVATION == "gelu": + acc = gelu(acc) + elif ACTIVATION == "gelu_approx": + acc = gelu_approx(acc) + elif ACTIVATION == "squared_relu": + acc = squared_relu(acc) + # rematerialize rm and rn to save registers + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + + # write back result + # C = C + rm[:, None] * stride_cm + rn[None, :] * stride_cn + C = C + rm[:, None] * stride_cm + rn[None, :] + mask = (rm < M)[:, None] & (rn < N)[None, :] + tl.store(C, acc, mask=mask) + + +def triton_linear_act( + x: torch.Tensor, + weight: torch.Tensor, + bias: Optional[torch.Tensor] = None, + activation: str = "id", + save_act_input: bool = False, +) -> torch.Tensor: + """ + Compute out = activation(x @ weight.T + bias). + This wrapper launches the `kernel_fwd` Triton kernel. + :param x: input tensor + :param weight: weight matrix + :param bias: an optional bias tensor + :param activation: Activation name. Needs to be a Triton kernel.
+ :param save_act_input: whether to also return the pre-activation (x @ weight.T + bias), needed for the backward pass + :return: result tensor, plus the saved pre-activation when save_act_input is True + """ + # if torch.is_autocast_enabled(): + # dtype = torch.get_autocast_gpu_dtype() + # x, weight, bias = [a.to(dtype=dtype) for a in [x, weight, bias]] + + assert activation in ["id", "gelu", "gelu_approx", "squared_relu"] + + batch_shape, n = x.shape[:-1], x.shape[-1] + batch_dim = batch_shape.numel() + x_reshaped = x.reshape(batch_dim, n) + + if x_reshaped.stride(0) > 1 and x_reshaped.stride(1) > 1: + x_reshaped = x_reshaped.contiguous() + if weight.stride(0) > 1 and weight.stride(1) > 1: + weight = weight.contiguous() + bias = bias.contiguous() if bias is not None else None + + assert ( + x.dtype == weight.dtype + ), f"Input and weight must have the same dtype, got {x.dtype} and {weight.dtype}" + if bias is not None: + assert ( + x.dtype == bias.dtype + ), f"Input and bias must have the same dtype, got {x.dtype} and {bias.dtype}" + assert ( + x_reshaped.shape[1] == weight.shape[1] + ), f"Incompatible dimensions: {x_reshaped.shape} - {weight.shape}" + + assert ( + bias is None or bias.shape[0] == weight.shape[0] + ), "Incompatible dimensions between weight and bias" + + M, K = x_reshaped.shape + N, K = weight.shape + + output = torch.empty((M, N), device=x.device, dtype=x.dtype) + act_input = torch.empty_like(output) if save_act_input else None + + # 1D launch kernel where each block gets its own program. + grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),) # noqa + + kernel_fwd[grid]( + output, + act_input, + x_reshaped, + weight, # data ptrs + bias if bias is not None else x, # auto skip bias if not present + M, # shapes + N, + K, + M // 32, # key for triton cache (limit number of compilations) + N // 32, + K // 32, + stride_cm=output.stride(0), # strides + # stride_cn=output.stride(1), + stride_am=x_reshaped.stride(0), + stride_ak=x_reshaped.stride(1), + stride_bk=weight.stride(1), + stride_bn=weight.stride(0), + BIAS=bias is not None, # optional fused bias + SAVE_ACT_INPUT=save_act_input, # optional save activation inputs + ACTIVATION=activation, # optional fused activation + A_ROWMAJOR=x_reshaped.stride(1) == 1, + B_COLMAJOR=weight.stride(1) == 1, + GROUP_M=8, # speed optimization: group the programs + ) + + if not save_act_input: + return output.reshape(*batch_shape, output.shape[-1]) + else: + return ( + output.reshape(*batch_shape, output.shape[-1]), + act_input.reshape(*batch_shape, act_input.shape[-1]), + ) + + +@triton.autotune( + configs=[ + triton.Config( + {"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=3, num_warps=8 + ), + triton.Config( + {"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=3, num_warps=8 + ), + triton.Config( + {"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4 + ), + triton.Config( + {"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4 + ), + triton.Config( + {"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4 + ), + triton.Config( + {"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4 + ), + triton.Config( + {"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4 + ), + triton.Config( + {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4 + ), + triton.Config( + {"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=5, num_warps=2 + 
), + # good for int8 + triton.Config( + {"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 128, "SPLIT_K": 1}, + num_stages=3, + num_warps=8, + ), + triton.Config( + {"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 128, "SPLIT_K": 1}, + num_stages=3, + num_warps=8, + ), + triton.Config( + {"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=4, num_warps=4 + ), + triton.Config( + {"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=4, num_warps=4 + ), + triton.Config( + {"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 128, "SPLIT_K": 1}, + num_stages=4, + num_warps=4, + ), + triton.Config( + {"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=4, num_warps=4 + ), + triton.Config( + {"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=4, num_warps=4 + ), + triton.Config( + {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=4, num_warps=4 + ), + triton.Config( + {"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=5, num_warps=2 + ), + ] + + get_configs_io_bound(), + key=["CACHE_KEY_M", "CACHE_KEY_N", "CACHE_KEY_K"], + prune_configs_by={ + "early_config_prune": early_config_prune, + "perf_model": estimate_matmul_time, + "top_k": 10, + }, +) +@triton.heuristics( + { + "EVEN_K": lambda args: args["K"] % (args["BLOCK_K"] * args["SPLIT_K"]) == 0, + } +) +@triton.jit +def kernel_bwd( + C, # Pointers to matrices + ACT_INPUT, + A, + B, + # Matrix dimensions + M, + N, + K, + CACHE_KEY_M, + CACHE_KEY_N, + CACHE_KEY_K, + # The stride variables represent how much to increase the ptr by when moving by 1 + # element in a particular dimension. E.g. stride_am is how much to increase a_ptr + # by to get the element one row down (A has M rows) + stride_cm, + # stride_cn, # Assume that stride_cn == 1 + stride_am, + stride_ak, + stride_bk, + stride_bn, + # Meta-parameters + BLOCK_M: tl.constexpr, + GROUP_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, + # split k not used, not performant with activation, kept because early_config_prune is expecting it + SPLIT_K: tl.constexpr, + EVEN_K: tl.constexpr, + ACTIVATION: tl.constexpr, +): + + """ + Kernel for computing the input gradient Out = (A x W) * activation_grad(ActInputs) + - A (the incoming gradient) has shape (M, K) + - Weight has shape (K, N) + - Output has shape (M, N) + - ActInputs has shape (M, N); it holds the pre-activation saved by the forward pass + and is only read when ACTIVATION != "id" + This kernel will consolidate over K + """ + + pid = tl.program_id(axis=0) + + grid_m = (M + BLOCK_M - 1) // BLOCK_M + grid_n = (N + BLOCK_N - 1) // BLOCK_N + # re-order program ID for better L2 performance + width = GROUP_M * grid_n + group_id = pid // width + group_size = min(grid_m - group_id * GROUP_M, GROUP_M) + pid_m = group_id * GROUP_M + (pid % group_size) + pid_n = (pid % width) // (group_size) + + # now compute the block that each program will go through + # rm (resp. rn) denotes a range of indices + # for rows (resp. 
col) of C + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + # trick to avoid masking on M and N axis + ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M) + rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N) + rk = tl.arange(0, BLOCK_K) + + A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak) + B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn) + + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) + + for k in range(K, 0, -BLOCK_K): + if EVEN_K: + a = tl.load(A) + b = tl.load(B) + else: + a = tl.load(A, mask=rk[None, :] < k, other=0.0) + b = tl.load(B, mask=rk[:, None] < k, other=0.0) + acc += tl.dot(a, b) + + A += BLOCK_K * stride_ak + B += BLOCK_K * stride_bk + + # optional: fused activation (while the data is in shared memory) + if ACTIVATION != "id": + act_in_ptrs = ACT_INPUT + ram[:, None] * stride_cm + rbn[None, :] + act_input = tl.load(act_in_ptrs).to(acc.dtype) + if ACTIVATION == "gelu": + acc *= gelu_grad(act_input) + elif ACTIVATION == "gelu_approx": + acc *= gelu_approx_grad(act_input) + elif ACTIVATION == "squared_relu": + acc *= squared_relu_grad(act_input) + + # rematerialize rm and rn to save registers + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + + # write back result + C = C + rm[:, None] * stride_cm + rn[None, :] + mask = (rm < M)[:, None] & (rn < N)[None, :] + tl.store(C, acc, mask=mask) + + +def triton_dgrad_act( + grad_output: torch.Tensor, + weight: torch.Tensor, + activation: str = "id", + act_input: Optional[torch.Tensor] = None, +) -> torch.Tensor: + """ + Compute grad_input = grad_output @ weight, scaled by the activation gradient + at act_input when activation != "id". + This wrapper launches the `kernel_bwd` Triton kernel. + :param grad_output: output gradient tensor + :param weight: weight matrix + :param activation: Activation name. Needs to be a Triton kernel. + :param act_input: the pre-activation saved during the forward pass (required when activation != "id") + :return: result tensor + """ + assert activation in ["id", "gelu", "gelu_approx", "squared_relu"] + + batch_shape, n = grad_output.shape[:-1], grad_output.shape[-1] + batch_dim = batch_shape.numel() + grad_output_reshaped = grad_output.reshape(batch_dim, n) + + if grad_output_reshaped.stride(0) > 1 and grad_output_reshaped.stride(1) > 1: + grad_output_reshaped = grad_output_reshaped.contiguous() + if weight.stride(0) > 1 and weight.stride(1) > 1: + weight = weight.contiguous() + + assert ( + grad_output.dtype == weight.dtype + ), f"grad_output and weight must have the same dtype, got {grad_output.dtype} and {weight.dtype}" + assert ( + grad_output_reshaped.shape[1] == weight.shape[0] + ), f"Incompatible dimensions: {grad_output_reshaped.shape} - {weight.shape}" + if activation != "id": + assert act_input is not None, f"act_input is required for activation {activation}" + + # M, N, K in bwd are different from M, N, K in fwd + M, K = grad_output_reshaped.shape + K, N = weight.shape + + grad_input = torch.empty((M, N), device=grad_output.device, dtype=grad_output.dtype) + + # 1D launch kernel where each block gets its own program. 
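# [Editor's note: illustrative addition.] The launch grid is one-dimensional with
# cdiv(M, BLOCK_M) * cdiv(N, BLOCK_N) programs. For example, M=4096, N=1024,
# BLOCK_M=128, BLOCK_N=64 gives 32 * 16 = 512 programs; kernel_bwd then recovers
# (pid_m, pid_n) from the flat pid with GROUP_M swizzling for better L2 reuse.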
+ grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),) # noqa + + kernel_bwd[grid]( + grad_input, + act_input, + grad_output_reshaped, + weight, # data ptrs + M, # shapes + N, + K, + M // 32, # key for triton cache (limit number of compilations) + N // 32, + K // 32, + stride_cm=grad_input.stride(0), # strides + # stride_cn=grad_input.stride(1), + stride_am=grad_output_reshaped.stride(0), + stride_ak=grad_output_reshaped.stride(1), + stride_bk=weight.stride(0), + stride_bn=weight.stride(1), + ACTIVATION=activation, # optional fused activation + GROUP_M=8, # speed optimization: group the programs + ) + + return grad_input.reshape(*batch_shape, grad_input.shape[-1]) diff --git a/.venv/lib/python3.11/site-packages/xformers/_flash_attn/ops/triton/mlp.py b/.venv/lib/python3.11/site-packages/xformers/_flash_attn/ops/triton/mlp.py new file mode 100644 index 0000000000000000000000000000000000000000..b795310f1c8afc8203124597bb6ca70f1af7ed11 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/xformers/_flash_attn/ops/triton/mlp.py @@ -0,0 +1,149 @@ +# The triton fused matmul + sqrelu is faster for fp16 but slower for bf16, compared +# to naive implementation. +import fused_dense_lib as fused_dense_cuda +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.cuda.amp import custom_bwd, custom_fwd + +from flash_attn.ops.activations import sqrelu_bwd, sqrelu_fwd +from flash_attn.ops.triton.linear import triton_dgrad_act, triton_linear_act + + +class FusedDenseSqreluDenseFunc(torch.autograd.Function): + @staticmethod + @custom_fwd + def forward(ctx, x, weight1, bias1, weight2, bias2, checkpoint_lvl=0): + """checkpoint_lvl: + 0: no recomputation in the bwd + 1: recompute the activation output (output1) in the bwd + 2: recompute the pre-activation (act_input) and the activation output in the bwd + """ + if torch.is_autocast_enabled(): + dtype = torch.get_autocast_gpu_dtype() + x, weight1, bias1, weight2, bias2 = [ + a.to(dtype=dtype) for a in [x, weight1, bias1, weight2, bias2] + ] + is_bf16 = x.dtype == torch.bfloat16 + assert checkpoint_lvl in [0, 1, 2] + x = x.contiguous() + weight1 = weight1.contiguous() + bias1 = bias1.contiguous() + weight2 = weight2.contiguous() + bias2 = bias2.contiguous() + batch_shape, n = x.shape[:-1], x.shape[-1] + batch_dim = batch_shape.numel() + if is_bf16: + act_input = fused_dense_cuda.linear_bias_forward( + x.reshape(batch_dim, n), weight1, bias1 + ) + output1 = sqrelu_fwd(act_input) + else: + save_act_input = checkpoint_lvl != 2 + result = triton_linear_act( + x.reshape(batch_dim, n), + weight1, + bias1, + activation="squared_relu", + save_act_input=save_act_input, + ) + if save_act_input: + output1, act_input = result + else: + output1 = result + output2 = fused_dense_cuda.linear_bias_forward(output1, weight2, bias2) + ctx.checkpoint_lvl = checkpoint_lvl + if checkpoint_lvl == 0: + ctx.save_for_backward(x, weight1, bias1, weight2, act_input, output1) + elif checkpoint_lvl == 1: + ctx.save_for_backward(x, weight1, bias1, weight2, act_input) + elif checkpoint_lvl == 2: + ctx.save_for_backward(x, weight1, bias1, weight2) + return output2.reshape(*batch_shape, output2.shape[-1]) + + @staticmethod + @custom_bwd + def backward(ctx, grad_output): + grad_output = grad_output.contiguous() + checkpoint_lvl = ctx.checkpoint_lvl + x, weight1, bias1, weight2, *rest = ctx.saved_tensors + batch_shape, n = x.shape[:-1], x.shape[-1] + batch_dim = batch_shape.numel() + is_bf16 = x.dtype == torch.bfloat16 + if checkpoint_lvl == 0: + act_input, output1 = rest + elif checkpoint_lvl 
== 1: + (act_input,) = rest + output1 = sqrelu_fwd(act_input) + elif checkpoint_lvl == 2: + if is_bf16: + act_input = fused_dense_cuda.linear_bias_forward( + x.reshape(batch_dim, n), weight1, bias1 + ) + output1 = sqrelu_fwd(act_input) + else: + output1, act_input = triton_linear_act( + x.reshape(batch_dim, n), + weight1, + bias1, + activation="squared_relu", + save_act_input=True, + ) + + if is_bf16: + grad_output = grad_output.reshape(batch_dim, grad_output.shape[-1]) + grad_weight2, grad_bias2 = fused_dense_cuda.linear_bias_wgrad(output1, grad_output) + grad_output1 = grad_output @ weight2 + grad_act_input = sqrelu_bwd(grad_output1, act_input) + grad_input, grad_weight1, grad_bias1 = fused_dense_cuda.linear_bias_backward( + x.reshape(batch_dim, n), weight1, grad_act_input + ) + else: + grad_output = grad_output.reshape(batch_dim, grad_output.shape[-1]) + grad_weight2, grad_bias2 = fused_dense_cuda.linear_bias_wgrad(output1, grad_output) + grad_act_input = triton_dgrad_act( + grad_output, weight2, activation="squared_relu", act_input=act_input + ) + grad_input, grad_weight1, grad_bias1 = fused_dense_cuda.linear_bias_backward( + x.reshape(batch_dim, n), weight1, grad_act_input + ) + return grad_input.reshape_as(x), grad_weight1, grad_bias1, grad_weight2, grad_bias2, None + + +fused_dense_sqrelu_dense_function = FusedDenseSqreluDenseFunc.apply + + +class FusedDenseSqreluDense(nn.Module): + def __init__( + self, + in_features, + hidden_features=None, + out_features=None, + bias1=True, + bias2=True, + checkpoint_lvl=0, + device=None, + dtype=None, + ): + """ + checkpoint_lvl (increasing lvl means slower but more memory saving): + 0: no recomputation in the bwd + 1: recompute the activation output in the bwd + 2: recompute the pre-activation and the activation output in the bwd + """ + assert checkpoint_lvl in [0, 1, 2] + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features * 4 + assert bias1 == True, "DenseSqreluDense module without bias is currently not supported" + assert bias2 == True, "DenseSqreluDense module without bias is currently not supported" + self.checkpoint_lvl = checkpoint_lvl + self.fc1 = nn.Linear(in_features, hidden_features, bias=bias1, **factory_kwargs) + self.fc2 = nn.Linear(hidden_features, out_features, bias=bias2, **factory_kwargs) + + def forward(self, x): + assert x.is_cuda + return fused_dense_sqrelu_dense_function( + x, self.fc1.weight, self.fc1.bias, self.fc2.weight, self.fc2.bias, self.checkpoint_lvl + ) diff --git a/.venv/lib/python3.11/site-packages/xformers/_flash_attn/ops/triton/rotary.py b/.venv/lib/python3.11/site-packages/xformers/_flash_attn/ops/triton/rotary.py new file mode 100644 index 0000000000000000000000000000000000000000..6c04a523ede814ea075e6773572cb56cac8bff64 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/xformers/_flash_attn/ops/triton/rotary.py @@ -0,0 +1,227 @@ +# Copyright (c) 2023, Tri Dao. 
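# [Editor's note: illustrative addition.] For a feature pair (x0, x1) at position m
# with per-pair frequency theta, rotary embedding computes
#     o0 = x0 * cos(m * theta) - x1 * sin(m * theta)
#     o1 = x0 * sin(m * theta) + x1 * cos(m * theta)
# The kernel below reads precomputed cos/sin tables rather than evaluating theta;
# CONJUGATE flips the sign of sin, which is what the backward pass needs.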
+ +from typing import Optional, Union + +import torch + +import triton +import triton.language as tl + + +@triton.jit +def rotary_kernel( + OUT, # Pointers to matrices + X, + COS, + SIN, + CU_SEQLENS, + SEQLEN_OFFSETS, # this could be int or a pointer + # Matrix dimensions + seqlen, + rotary_dim, + seqlen_ro, + # strides + stride_out_batch, + stride_out_seqlen, + stride_out_nheads, + stride_out_headdim, + stride_x_batch, + stride_x_seqlen, + stride_x_nheads, + stride_x_headdim, + # Meta-parameters + BLOCK_K: tl.constexpr, + IS_SEQLEN_OFFSETS_TENSOR: tl.constexpr, + IS_VARLEN: tl.constexpr, + INTERLEAVED: tl.constexpr, + CONJUGATE: tl.constexpr, + BLOCK_M: tl.constexpr, +): + pid_m = tl.program_id(axis=0) + pid_batch = tl.program_id(axis=1) + pid_head = tl.program_id(axis=2) + rotary_dim_half = rotary_dim // 2 + + if not IS_VARLEN: + X = X + pid_batch * stride_x_batch + pid_head * stride_x_nheads + OUT = OUT + pid_batch * stride_out_batch + pid_head * stride_out_nheads + else: + start_idx = tl.load(CU_SEQLENS + pid_batch) + seqlen = tl.load(CU_SEQLENS + pid_batch + 1) - start_idx + X = X + start_idx * stride_x_seqlen + pid_head * stride_x_nheads + OUT = OUT + start_idx * stride_out_seqlen + pid_head * stride_out_nheads + + if pid_m * BLOCK_M >= seqlen: + return + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + if not IS_SEQLEN_OFFSETS_TENSOR: + rm_cs = rm + SEQLEN_OFFSETS + else: + rm_cs = rm + tl.load(SEQLEN_OFFSETS + pid_batch) + rk = tl.arange(0, BLOCK_K) + rk_half = tl.arange(0, BLOCK_K // 2) + + if not INTERLEAVED: + # Load the 1st and 2nd halves of X, do calculation, then store to 1st and 2nd halves of OUT + X = X + (rm[:, None] * stride_x_seqlen + rk_half[None, :] * stride_x_headdim) + COS = COS + (rm_cs[:, None] * rotary_dim_half + rk_half[None, :]) + SIN = SIN + (rm_cs[:, None] * rotary_dim_half + rk_half[None, :]) + cos = tl.load( + COS, mask=(rm_cs[:, None] < seqlen_ro) & (rk_half[None, :] < rotary_dim_half), other=1.0 + ).to(tl.float32) + sin = tl.load( + SIN, mask=(rm_cs[:, None] < seqlen_ro) & (rk_half[None, :] < rotary_dim_half), other=0.0 + ).to(tl.float32) + x0 = tl.load( + X, mask=(rm[:, None] < seqlen) & (rk_half[None, :] < rotary_dim_half), other=0.0 + ).to(tl.float32) + x1 = tl.load( + X + rotary_dim_half * stride_x_headdim, + mask=(rm[:, None] < seqlen) & (rk_half[None, :] < rotary_dim_half), + other=0.0, + ).to(tl.float32) + if CONJUGATE: + sin = -sin + o0 = x0 * cos - x1 * sin + o1 = x0 * sin + x1 * cos + # write back result + OUT = OUT + (rm[:, None] * stride_out_seqlen + rk_half[None, :] * stride_out_headdim) + tl.store(OUT, o0, mask=(rm[:, None] < seqlen) & (rk_half[None, :] < rotary_dim_half)) + tl.store( + OUT + rotary_dim_half * stride_out_headdim, + o1, + mask=(rm[:, None] < seqlen) & (rk_half[None, :] < rotary_dim_half), + ) + else: + # We don't want to load X[0, 2, 4, ...] and X[1, 3, 5, ...] separately since both are slow. + # Instead, we load x0 = X[0, 1, 2, 3, ...] and x1 = X[1, 0, 3, 2, ...]. + # Loading x0 will be fast but x1 will be slow. + # Then we load cos = COS[0, 0, 1, 1, ...] and sin = SIN[0, 0, 1, 1, ...]. + # Then we do the calculation and use tl.where to pick out the right outputs for the even + # and for the odd indices. + rk_swap = rk + ((rk + 1) % 2) * 2 - 1 # 1, 0, 3, 2, 5, 4, ... 
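# [Editor's note: worked example.] With BLOCK_K = 8:
#     rk      = [0, 1, 2, 3, 4, 5, 6, 7]
#     rk_swap = [1, 0, 3, 2, 5, 4, 7, 6]
# so x1 loaded through rk_swap is each element's rotation partner in the
# interleaved even/odd layout.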
+ rk_repeat = tl.arange(0, BLOCK_K) // 2 + X0 = X + (rm[:, None] * stride_x_seqlen + rk[None, :] * stride_x_headdim) + X1 = X + (rm[:, None] * stride_x_seqlen + rk_swap[None, :] * stride_x_headdim) + COS = COS + (rm_cs[:, None] * rotary_dim_half + rk_repeat[None, :]) + SIN = SIN + (rm_cs[:, None] * rotary_dim_half + rk_repeat[None, :]) + cos = tl.load( + COS, + mask=(rm_cs[:, None] < seqlen_ro) & (rk_repeat[None, :] < rotary_dim_half), + other=1.0, + ).to(tl.float32) + sin = tl.load( + SIN, + mask=(rm_cs[:, None] < seqlen_ro) & (rk_repeat[None, :] < rotary_dim_half), + other=0.0, + ).to(tl.float32) + x0 = tl.load(X0, mask=(rm[:, None] < seqlen) & (rk[None, :] < rotary_dim), other=0.0).to( + tl.float32 + ) + x1 = tl.load( + X1, mask=(rm[:, None] < seqlen) & (rk_swap[None, :] < rotary_dim), other=0.0 + ).to(tl.float32) + if CONJUGATE: + sin = -sin + x0_cos = x0 * cos + x1_sin = x1 * sin + out = tl.where(rk[None, :] % 2 == 0, x0_cos - x1_sin, x0_cos + x1_sin) + OUT = OUT + (rm[:, None] * stride_out_seqlen + rk[None, :] * stride_out_headdim) + tl.store(OUT, out, mask=(rm[:, None] < seqlen) & (rk[None, :] < rotary_dim)) + + +def apply_rotary( + x: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + seqlen_offsets: Union[int, torch.Tensor] = 0, + cu_seqlens: Optional[torch.Tensor] = None, + max_seqlen: Optional[int] = None, + interleaved=False, + inplace=False, + conjugate=False, +) -> torch.Tensor: + """ + Arguments: + x: (batch, seqlen, nheads, headdim) if cu_seqlens is None + else (total_seqlen, nheads, headdim). + cos: (seqlen_ro, rotary_dim / 2) + sin: (seqlen_ro, rotary_dim / 2) + seqlen_offsets: integer or integer tensor of size (batch,) + cu_seqlens: (batch + 1,) or None + max_seqlen: int + Returns: + y: (batch, seqlen, nheads, headdim) + """ + is_varlen = cu_seqlens is not None + if not is_varlen: + batch, seqlen, nheads, headdim = x.shape + else: + assert max_seqlen is not None, "If cu_seqlens is passed in, then max_seqlen must be passed" + total_seqlen, nheads, headdim = x.shape + batch_p_1 = cu_seqlens.shape[0] + batch = batch_p_1 - 1 + seqlen = max_seqlen + seqlen_ro, rotary_dim = cos.shape + assert sin.shape == cos.shape + rotary_dim *= 2 + assert rotary_dim <= headdim, "rotary_dim must be <= headdim" + assert headdim <= 256, "Only support headdim <= 256" + assert seqlen_ro >= seqlen, "seqlen_ro must be >= seqlen" + + assert ( + cos.dtype == sin.dtype + ), f"cos and sin must have the same dtype, got {cos.dtype} and {sin.dtype}" + assert ( + x.dtype == cos.dtype + ), f"Input and cos/sin must have the same dtype, got {x.dtype} and {cos.dtype}" + + cos, sin = cos.contiguous(), sin.contiguous() + if isinstance(seqlen_offsets, torch.Tensor): + assert seqlen_offsets.shape == (batch,) + assert seqlen_offsets.dtype in [torch.int32, torch.int64] + seqlen_offsets = seqlen_offsets.contiguous() + else: + assert seqlen_offsets + seqlen <= seqlen_ro + + output = torch.empty_like(x) if not inplace else x + if rotary_dim < headdim and not inplace: + output[..., rotary_dim:].copy_(x[..., rotary_dim:]) + + BLOCK_K = ( + 32 + if rotary_dim <= 32 + else (64 if rotary_dim <= 64 else (128 if rotary_dim <= 128 else 256)) + ) + grid = lambda META: (triton.cdiv(seqlen, META["BLOCK_M"]), batch, nheads) # noqa + BLOCK_M = 4 if interleaved else (8 if rotary_dim <= 64 else 4) + + # Need this, otherwise Triton tries to launch from cuda:0 and we get + # ValueError: Pointer argument (at 0) cannot be accessed from Triton (cpu tensor?) 
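# [Editor's note: illustrative addition.] torch.cuda.device makes x's device current
# for the launch, so multi-GPU callers don't accidentally target cuda:0. A
# hypothetical call matching the docstring shapes:
#     out = apply_rotary(q, cos, sin, interleaved=False)
# with q of shape (batch, seqlen, nheads, headdim) and cos/sin of shape
# (seqlen_ro, rotary_dim // 2).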
+ with torch.cuda.device(x.device.index): + rotary_kernel[grid]( + output, # data ptrs + x, + cos, + sin, + cu_seqlens, + seqlen_offsets, + seqlen, # shapes + rotary_dim, + seqlen_ro, + output.stride(0) if not is_varlen else 0, # batch_strides if not varlen else 0 + output.stride(-3), # seqlen_stride or total_seqlen_stride + output.stride(-2), # nheads_stride + output.stride(-1), # headdim_stride + x.stride(0) if not is_varlen else 0, # batch_strides if not varlen else 0 + x.stride(-3), # seqlen stride or total_seqlen_stride + x.stride(-2), # nheads stride + x.stride(-1), # headdim stride + BLOCK_K, + isinstance(seqlen_offsets, torch.Tensor), + is_varlen, + interleaved, + conjugate, + BLOCK_M, + ) + return output diff --git a/.venv/lib/python3.11/site-packages/xformers/_flash_attn/utils/__init__.py b/.venv/lib/python3.11/site-packages/xformers/_flash_attn/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/.venv/lib/python3.11/site-packages/xformers/_flash_attn/utils/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/xformers/_flash_attn/utils/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fe7b155fd2d2770b8972b22db581f052cd3b8a6f Binary files /dev/null and b/.venv/lib/python3.11/site-packages/xformers/_flash_attn/utils/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/xformers/_flash_attn/utils/__pycache__/benchmark.cpython-311.pyc b/.venv/lib/python3.11/site-packages/xformers/_flash_attn/utils/__pycache__/benchmark.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..048c5404bf114704c4e2999c515435c5a2cf8ee0 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/xformers/_flash_attn/utils/__pycache__/benchmark.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/xformers/_flash_attn/utils/__pycache__/distributed.cpython-311.pyc b/.venv/lib/python3.11/site-packages/xformers/_flash_attn/utils/__pycache__/distributed.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8fee516e49558c00a19e5556330e774a0fc945a5 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/xformers/_flash_attn/utils/__pycache__/distributed.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/xformers/_flash_attn/utils/__pycache__/generation.cpython-311.pyc b/.venv/lib/python3.11/site-packages/xformers/_flash_attn/utils/__pycache__/generation.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f98df6a4a49376cbc4db50137c5a98cc48be6ca3 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/xformers/_flash_attn/utils/__pycache__/generation.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/xformers/_flash_attn/utils/__pycache__/pretrained.cpython-311.pyc b/.venv/lib/python3.11/site-packages/xformers/_flash_attn/utils/__pycache__/pretrained.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e2dfa85fdc27d4fbfc8a48f2e76718ea2e810be0 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/xformers/_flash_attn/utils/__pycache__/pretrained.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/xformers/_flash_attn/utils/benchmark.py b/.venv/lib/python3.11/site-packages/xformers/_flash_attn/utils/benchmark.py new file mode 100644 index 
0000000000000000000000000000000000000000..15b30405f209921189b75f7307814876350e7317 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/xformers/_flash_attn/utils/benchmark.py @@ -0,0 +1,268 @@ +# Copyright (c) 2023, Tri Dao. +""" Useful functions for writing test code. """ + +import torch +import torch.utils.benchmark as benchmark + + +def benchmark_forward( + fn, *inputs, repeats=10, desc="", verbose=True, amp=False, amp_dtype=torch.float16, **kwinputs +): + """Use Pytorch Benchmark on the forward pass of an arbitrary function.""" + if verbose: + print(desc, "- Forward pass") + + def amp_wrapper(*inputs, **kwinputs): + with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp): + fn(*inputs, **kwinputs) + + t = benchmark.Timer( + stmt="fn_amp(*inputs, **kwinputs)", + globals={"fn_amp": amp_wrapper, "inputs": inputs, "kwinputs": kwinputs}, + num_threads=torch.get_num_threads(), + ) + m = t.timeit(repeats) + if verbose: + print(m) + return t, m + + +def benchmark_backward( + fn, + *inputs, + grad=None, + repeats=10, + desc="", + verbose=True, + amp=False, + amp_dtype=torch.float16, + **kwinputs, +): + """Use Pytorch Benchmark on the backward pass of an arbitrary function.""" + if verbose: + print(desc, "- Backward pass") + with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp): + y = fn(*inputs, **kwinputs) + if type(y) is tuple: + y = y[0] + if grad is None: + grad = torch.randn_like(y) + else: + if grad.shape != y.shape: + raise RuntimeError("Grad shape does not match output shape") + + def f(*inputs, y, grad): + # Set .grad to None to avoid extra operation of gradient accumulation + for x in inputs: + if isinstance(x, torch.Tensor): + x.grad = None + y.backward(grad, retain_graph=True) + + t = benchmark.Timer( + stmt="f(*inputs, y=y, grad=grad)", + globals={"f": f, "inputs": inputs, "y": y, "grad": grad}, + num_threads=torch.get_num_threads(), + ) + m = t.timeit(repeats) + if verbose: + print(m) + return t, m + + +def benchmark_combined( + fn, + *inputs, + grad=None, + repeats=10, + desc="", + verbose=True, + amp=False, + amp_dtype=torch.float16, + **kwinputs, +): + """Use Pytorch Benchmark on the forward+backward pass of an arbitrary function.""" + if verbose: + print(desc, "- Forward + Backward pass") + with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp): + y = fn(*inputs, **kwinputs) + if type(y) is tuple: + y = y[0] + if grad is None: + grad = torch.randn_like(y) + else: + if grad.shape != y.shape: + raise RuntimeError("Grad shape does not match output shape") + + def f(grad, *inputs, **kwinputs): + for x in inputs: + if isinstance(x, torch.Tensor): + x.grad = None + with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp): + y = fn(*inputs, **kwinputs) + if type(y) is tuple: + y = y[0] + y.backward(grad, retain_graph=True) + + t = benchmark.Timer( + stmt="f(grad, *inputs, **kwinputs)", + globals={"f": f, "fn": fn, "inputs": inputs, "grad": grad, "kwinputs": kwinputs}, + num_threads=torch.get_num_threads(), + ) + m = t.timeit(repeats) + if verbose: + print(m) + return t, m + + +def benchmark_fwd_bwd( + fn, + *inputs, + grad=None, + repeats=10, + desc="", + verbose=True, + amp=False, + amp_dtype=torch.float16, + **kwinputs, +): + """Use Pytorch Benchmark on the forward+backward pass of an arbitrary function.""" + return ( + benchmark_forward( + fn, + *inputs, + repeats=repeats, + desc=desc, + verbose=verbose, + amp=amp, + amp_dtype=amp_dtype, + **kwinputs, + ), + benchmark_backward( + fn, + *inputs, + grad=grad, + 
repeats=repeats, + desc=desc, + verbose=verbose, + amp=amp, + amp_dtype=amp_dtype, + **kwinputs, + ), + ) + + +def benchmark_all( + fn, + *inputs, + grad=None, + repeats=10, + desc="", + verbose=True, + amp=False, + amp_dtype=torch.float16, + **kwinputs, +): + """Use Pytorch Benchmark on the forward+backward pass of an arbitrary function.""" + return ( + benchmark_forward( + fn, + *inputs, + repeats=repeats, + desc=desc, + verbose=verbose, + amp=amp, + amp_dtype=amp_dtype, + **kwinputs, + ), + benchmark_backward( + fn, + *inputs, + grad=grad, + repeats=repeats, + desc=desc, + verbose=verbose, + amp=amp, + amp_dtype=amp_dtype, + **kwinputs, + ), + benchmark_combined( + fn, + *inputs, + grad=grad, + repeats=repeats, + desc=desc, + verbose=verbose, + amp=amp, + amp_dtype=amp_dtype, + **kwinputs, + ), + ) + + +def pytorch_profiler( + fn, + *inputs, + trace_filename=None, + backward=False, + amp=False, + amp_dtype=torch.float16, + cpu=False, + verbose=True, + **kwinputs, +): + """Wrap benchmark functions in Pytorch profiler to see CUDA information.""" + if backward: + with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp): + out = fn(*inputs, **kwinputs) + if type(out) is tuple: + out = out[0] + g = torch.randn_like(out) + for _ in range(30): # Warm up + if backward: + for x in inputs: + if isinstance(x, torch.Tensor): + x.grad = None + with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp): + out = fn(*inputs, **kwinputs) + if type(out) is tuple: + out = out[0] + # Backward should be done outside autocast + if backward: + out.backward(g, retain_graph=True) + activities = ([torch.profiler.ProfilerActivity.CPU] if cpu else []) + [ + torch.profiler.ProfilerActivity.CUDA + ] + with torch.profiler.profile( + activities=activities, + record_shapes=True, + # profile_memory=True, + with_stack=True, + ) as prof: + if backward: + for x in inputs: + if isinstance(x, torch.Tensor): + x.grad = None + with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp): + out = fn(*inputs, **kwinputs) + if type(out) is tuple: + out = out[0] + if backward: + out.backward(g, retain_graph=True) + if verbose: + # print(prof.key_averages().table(sort_by="self_cuda_time_total", row_limit=50)) + print(prof.key_averages().table(row_limit=50)) + if trace_filename is not None: + prof.export_chrome_trace(trace_filename) + + +def benchmark_memory(fn, *inputs, desc="", verbose=True, **kwinputs): + torch.cuda.empty_cache() + torch.cuda.reset_peak_memory_stats() + torch.cuda.synchronize() + fn(*inputs, **kwinputs) + torch.cuda.synchronize() + mem = torch.cuda.max_memory_allocated() / ((2**20) * 1000) + if verbose: + print(f"{desc} max memory: {mem}GB") + torch.cuda.empty_cache() + return mem diff --git a/.venv/lib/python3.11/site-packages/xformers/_flash_attn/utils/distributed.py b/.venv/lib/python3.11/site-packages/xformers/_flash_attn/utils/distributed.py new file mode 100644 index 0000000000000000000000000000000000000000..74c55279645cd0fd687584bc1b7374c8c3c73e56 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/xformers/_flash_attn/utils/distributed.py @@ -0,0 +1,144 @@ +from typing import Optional + +import torch +from torch import Tensor +from torch.distributed import ProcessGroup + +# `all_gather_into_tensor` and `reduce_scatter_tensor` are new placeholders for +# `_all_gather_base` and `_reduce_scatter_base`. They require the most recent +# version of PyTorch. The following 4 lines are for backward compatibility with +# older PyTorch. 
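
An aside on the benchmark helpers from benchmark.py above: each one wraps torch.utils.benchmark.Timer and returns a (Timer, Measurement) pair, so the mean runtime is available as the Measurement's .mean attribute, in seconds. A minimal usage sketch, assuming a CUDA device; the toy workload my_matmul and the tensor shapes are made up for illustration:

    import torch
    from xformers._flash_attn.utils.benchmark import benchmark_all, benchmark_forward

    def my_matmul(a, b):  # hypothetical workload; any differentiable fn(*inputs) works
        return a @ b

    a = torch.randn(1024, 1024, device="cuda", requires_grad=True)
    b = torch.randn(1024, 1024, device="cuda", requires_grad=True)

    # Forward pass only, under autocast (amp=True), averaged over 30 repeats.
    _, m = benchmark_forward(my_matmul, a, b, repeats=30, desc="matmul", amp=True)
    print(f"forward: {m.mean * 1e3:.3f} ms")

    # Forward, backward, and combined timings in one call.
    benchmark_all(my_matmul, a, b, repeats=30, desc="matmul", amp=True)

Note that benchmark_backward and benchmark_combined synthesize a random output gradient with torch.randn_like(y) when grad is not supplied, so the inputs must require gradients.
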
+if "all_gather_into_tensor" not in dir(torch.distributed): + torch.distributed.all_gather_into_tensor = torch.distributed._all_gather_base +if "reduce_scatter_tensor" not in dir(torch.distributed): + torch.distributed.reduce_scatter_tensor = torch.distributed._reduce_scatter_base + + +# Raw operation, does not support autograd, but does support async +def all_gather_raw(input_: Tensor, process_group: ProcessGroup, async_op: bool = False): + world_size = torch.distributed.get_world_size(process_group) + output = torch.empty( + world_size * input_.shape[0], *input_.shape[1:], dtype=input_.dtype, device=input_.device + ) + handle = torch.distributed.all_gather_into_tensor( + output, input_.contiguous(), group=process_group, async_op=async_op + ) + return output, handle + + +# Raw operation, does not support autograd, but does support async +def reduce_scatter_raw(input_: Tensor, process_group: ProcessGroup, async_op: bool = False): + world_size = torch.distributed.get_world_size(process_group) + assert input_.shape[0] % world_size == 0 + output = torch.empty( + input_.shape[0] // world_size, *input_.shape[1:], dtype=input_.dtype, device=input_.device + ) + handle = torch.distributed.reduce_scatter_tensor( + output, input_.contiguous(), group=process_group, async_op=async_op + ) + return output, handle + + +# Raw operation, does not support autograd, but does support async +def all_reduce_raw(input_: Tensor, process_group: ProcessGroup, async_op: bool = False): + input_ = input_.contiguous() + handle = torch.distributed.all_reduce(input_, group=process_group, async_op=async_op) + return input_, handle + + +class AllGatherFunc(torch.autograd.Function): + """Gather the input from sequence parallel region and concatenate.""" + + @staticmethod + def forward(ctx, input_: Tensor, process_group: ProcessGroup) -> Tensor: + ctx.process_group = process_group + output, _ = all_gather_raw(input_, process_group) + return output + + @staticmethod + def backward(ctx, grad_output: Tensor): + grad_input, _ = reduce_scatter_raw(grad_output, ctx.process_group) + return grad_input, None + + +# Supports autograd, but does not support async +all_gather = AllGatherFunc.apply + + +class ReduceScatterFunc(torch.autograd.Function): + """Reduce scatter the input from the sequence parallel region and concatenate.""" + + @staticmethod + def forward(ctx, input_: Tensor, process_group: ProcessGroup) -> Tensor: + ctx.process_group = process_group + output, _ = reduce_scatter_raw(input_, process_group) + return output + + @staticmethod + def backward(ctx, grad_output: Tensor): + grad_input, _ = all_gather_raw(grad_output, ctx.process_group) + return grad_input, None + + +# Supports autograd, but does not support async +reduce_scatter = ReduceScatterFunc.apply + + +class AllReduceFunc(torch.autograd.Function): + """Gather the input from sequence parallel region and concatenate.""" + + @staticmethod + def forward(ctx, input_: Tensor, process_group: ProcessGroup) -> Tensor: + ctx.process_group = process_group + output, _ = all_reduce_raw(input_, process_group) + return output + + @staticmethod + def backward(ctx, grad_output: Tensor): + return grad_output, None + + +# Supports autograd, but does not support async +all_reduce = AllReduceFunc.apply + + +def sync_shared_params(model: torch.nn.Module, process_group: ProcessGroup): + # We want to iterate over parameters with _shared_params=True in the same order, + # as different ranks might have different number of parameters (e.g., only rank 0 has bias). 
+    params_shared = {
+        name: p for name, p in model.named_parameters() if getattr(p, "_shared_params", False)
+    }
+    for _, p in sorted(params_shared.items()):
+        with torch.no_grad():
+            # Broadcast needs src to be global rank, not group rank
+            torch.distributed.broadcast(
+                p, src=torch.distributed.get_global_rank(process_group, 0), group=process_group
+            )
+
+
+# Ref: https://github.com/NVIDIA/Megatron-LM/blob/52e636888cccc41e931251c417a7181fc36de926/megatron/optimizer/optimizer.py#L256
+def allreduce_sequence_parallel_grad(model: torch.nn.Module, process_group: ProcessGroup):
+    # We want to iterate over parameters with _sequence_parallel=True in the same order,
+    # as different ranks might have different number of parameters (e.g., only rank 0 has bias).
+    params_seqparallel = {
+        name: p for name, p in model.named_parameters() if getattr(p, "_sequence_parallel", False)
+    }
+    grads = [p.grad for _, p in sorted(params_seqparallel.items())]
+    if grads:
+        with torch.no_grad():
+            coalesced = torch._utils._flatten_dense_tensors(grads)
+            torch.distributed.all_reduce(coalesced, group=process_group)
+            for buf, synced in zip(grads, torch._utils._unflatten_dense_tensors(coalesced, grads)):
+                buf.copy_(synced)
+
+
+def get_dim_for_local_rank(dim: int, world_size: int, local_rank: int, multiple_of: int = 1) -> int:
+    """Get the dim for the local rank derived from splitting dim on world_size processes.
+
+    The split may not be even across the world_size processes.
+    """
+    multiple = dim // multiple_of
+    div = multiple // world_size
+    mod = multiple % world_size
+    local_multiple = div + int(local_rank < mod)
+    return local_multiple * multiple_of
diff --git a/.venv/lib/python3.11/site-packages/xformers/_flash_attn/utils/generation.py b/.venv/lib/python3.11/site-packages/xformers/_flash_attn/utils/generation.py
new file mode 100644
index 0000000000000000000000000000000000000000..0d9120c386596f25b544391af10dc479cf00c822
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/xformers/_flash_attn/utils/generation.py
@@ -0,0 +1,740 @@
+# Copyright (c) 2023, Tri Dao.
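
An aside on distributed.py above: all_gather, reduce_scatter, and all_reduce are the autograd-aware entry points (the *_raw variants are cheaper and support async_op, but build no graph), while get_dim_for_local_rank is a pure helper for computing uneven tensor-parallel shard sizes. A small worked example of the latter, runnable without any process group:

    from xformers._flash_attn.utils.distributed import get_dim_for_local_rank

    # Split dim=100 over 3 ranks, keeping every shard a multiple of 4:
    # 100 // 4 = 25 units; 25 = 3 * 8 + 1, so rank 0 gets 9 units (36 channels)
    # and ranks 1 and 2 get 8 units (32 channels) each.
    dims = [get_dim_for_local_rank(100, world_size=3, local_rank=r, multiple_of=4) for r in range(3)]
    assert dims == [36, 32, 32] and sum(dims) == 100
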
+# Adapted from https://github.com/NVIDIA/Megatron-LM/blob/0bb597b42c53355a567aba2a1357cc34b9d99ddd/megatron/text_generation/forward_step.py#L31
+import gc
+import time
+from collections import namedtuple
+from dataclasses import dataclass, field
+from functools import partial
+from typing import Callable, Optional, Sequence, Union
+
+import torch
+import torch.nn.functional as F
+from einops import rearrange, repeat
+from torch import Tensor
+from torch.profiler import ProfilerActivity, profile, record_function
+
+try:
+    from transformers.generation import GreedySearchDecoderOnlyOutput, SampleDecoderOnlyOutput
+except ImportError:
+    GreedySearchDecoderOnlyOutput = namedtuple("GreedySearchDecoderOnlyOutput", ["sequences", "scores"])
+    SampleDecoderOnlyOutput = namedtuple("SampleDecoderOnlyOutput", ["sequences", "scores"])
+
+
+@dataclass
+class InferenceParams:
+    """Inference parameters that are passed to the main model in order
+    to efficiently calculate and store the context during inference."""
+
+    max_seqlen: int
+    max_batch_size: int
+    seqlen_offset: int = 0
+    batch_size_offset: int = 0
+    key_value_memory_dict: dict = field(default_factory=dict)
+    lengths_per_sample: Optional[Tensor] = None
+
+    def reset(self, max_seqlen, max_batch_size):
+        self.max_seqlen = max_seqlen
+        self.max_batch_size = max_batch_size
+        self.seqlen_offset = 0
+        if self.lengths_per_sample is not None:
+            self.lengths_per_sample.zero_()
+
+
+# https://github.com/NVIDIA/Megatron-LM/blob/0bb597b42c53355a567aba2a1357cc34b9d99ddd/megatron/text_generation/sampling.py
+# https://github.com/huggingface/transformers/blob/a44985b41cfa2de48a5e1de7f1f93b7483da25d1/src/transformers/generation/logits_process.py#L231
+def modify_logits_for_top_k_filtering(logits, top_k):
+    """Set the logits for non-top-k values to -inf. Done in-place."""
+    indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
+    logits.masked_fill_(indices_to_remove, float("-Inf"))
+
+
+# https://github.com/NVIDIA/Megatron-LM/blob/0bb597b42c53355a567aba2a1357cc34b9d99ddd/megatron/text_generation/sampling.py
+# https://github.com/huggingface/transformers/blob/a44985b41cfa2de48a5e1de7f1f93b7483da25d1/src/transformers/generation/logits_process.py#L170
+def modify_logits_for_top_p_filtering(logits, top_p):
+    """Set the logits for non-top-p values to -inf. Done in-place."""
+    if top_p <= 0.0 or top_p >= 1.0:
+        return
+    # First sort and calculate cumulative sum of probabilities.
+    sorted_logits, sorted_indices = torch.sort(logits, descending=False)
+    cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
+    # Remove tokens with cumulative top_p above the threshold (tokens with 0 are kept)
+    sorted_indices_to_remove = cumulative_probs <= (1 - top_p)
+    # scatter sorted tensors to original indexing
+    indices_to_remove = sorted_indices_to_remove.scatter(
+        1, sorted_indices, sorted_indices_to_remove
+    )
+    logits.masked_fill_(indices_to_remove, float("-inf"))
+
+
+def sample(logits, top_k=1, top_p=0.0, temperature=1.0):
+    """Sample from top-k logits.
+    Arguments:
+        logits: Tensor of shape (batch_size, vocab_size)
+    """
+    if top_k == 1:  # Short-circuit for greedy decoding
+        return logits.argmax(dim=-1)
+    else:
+        if top_p > 0.0:
+            assert top_p <= 1.0, "top-p should be in (0, 1]."
+ if top_k > 0: + top_k = min(top_k, logits.size(-1)) # Safety check + logits_top, indices = torch.topk(logits, top_k, dim=-1) + if temperature != 1.0: + logits_top /= temperature + modify_logits_for_top_p_filtering(logits_top, top_p) + return indices[ + torch.arange(indices.shape[0], device=indices.device), + torch.multinomial(torch.softmax(logits_top, dim=-1), num_samples=1).squeeze(dim=-1), + ] + else: + # Clone so that when we modify for top_p we don't change the original logits + logits_top = logits / temperature if temperature != 1.0 else logits.clone() + modify_logits_for_top_p_filtering(logits_top, top_p) + return torch.multinomial(torch.softmax(logits_top, dim=-1), num_samples=1).squeeze( + dim=-1 + ) + + +@torch.inference_mode() +def decode( + input_ids, + model, + max_length, + top_k=1, + top_p=0.0, + temperature=1.0, + eos_token_id=None, + teacher_outputs=None, + vocab_size=None, + tensor_parallel=1, + cg=False, + enable_timing=False, +): + """Decoding, either greedy or with top-k or top-p sampling. + If top-k = 0, don't limit the number of candidates (pure sampling). + Top-k and top-p can be used together. If top_k > 0 and top_p > 0, then top-k is applied first, + then top-p. + We assume that all sequences in the same batch have the same length. + + Arguments: + input_ids: (batch, seq_len) + max_length: int + teacher_outputs (optional): (batch, seq_len). If provided, instead of sampling from the + logits, the next token is taken from the teacher_outputs. Useful for testing. + Returns: GreedySearchDecoderOnlyOutput or SampleDecoderOnlyOutput, with the following fields: + sequences: (batch, max_length) + scores: tuples of (batch, vocab_size) + """ + batch_size, seqlen_og = input_ids.shape + teacher_output_len = teacher_outputs.shape[1] if teacher_outputs is not None else 0 + if cg: + if not hasattr(model, "_decoding_cache"): + model._decoding_cache = None + model._decoding_cache = update_graph_cache( + model, + model._decoding_cache, + batch_size, + seqlen_og, + max_length, + tensor_parallel=tensor_parallel, + ) + inference_params = model._decoding_cache.inference_params + inference_params.reset(max_length, batch_size) + else: + inference_params = InferenceParams(max_seqlen=max_length, max_batch_size=batch_size) + + def get_logits(input_ids, inference_params): + decoding = inference_params.seqlen_offset > 0 + if decoding: + position_ids = torch.full( + (batch_size, 1), + inference_params.seqlen_offset, + dtype=torch.long, + device=input_ids.device, + ) + else: + position_ids = None + if not cg or not decoding: + logits = model( + input_ids, + position_ids=position_ids, + inference_params=inference_params, + num_last_tokens=1, + ).logits.squeeze(dim=1) + else: + logits = model._decoding_cache.run( + input_ids, position_ids, inference_params.seqlen_offset + ).squeeze(dim=1) + return logits[..., :vocab_size] if vocab_size is not None else logits + + def sample_tokens(logits, inference_params): + if teacher_outputs is None or teacher_output_len <= inference_params.seqlen_offset: + token = sample(logits, top_k=top_k, top_p=top_p, temperature=temperature) + else: + token = teacher_outputs[:, inference_params.seqlen_offset] + # return rearrange(token, "b -> b 1") + return token.unsqueeze(1) + + def should_stop(current_token, inference_params): + if inference_params.seqlen_offset == 0: + return False + if eos_token_id is not None and (current_token == eos_token_id).all(): + return True + if inference_params.seqlen_offset >= max_length - 1: + return True + return False + + start = 
torch.cuda.Event(enable_timing=enable_timing) + end = torch.cuda.Event(enable_timing=enable_timing) + + if enable_timing: + if tensor_parallel > 1: + torch.distributed.barrier() + start.record() + scores, sequences = [], [input_ids] + while not should_stop(sequences[-1], inference_params): + scores.append(get_logits(sequences[-1], inference_params)) + inference_params.seqlen_offset += sequences[-1].shape[1] + sequences.append(sample_tokens(scores[-1], inference_params)) + if enable_timing: + end.record() + if tensor_parallel > 1: + torch.distributed.barrier() + torch.cuda.synchronize() + print(f"Prompt processing + decoding time: {(start.elapsed_time(end)):.0f}ms") + output_cls = GreedySearchDecoderOnlyOutput if top_k == 1 else SampleDecoderOnlyOutput + return output_cls(sequences=torch.cat(sequences, dim=1), scores=tuple(scores)) + + +def sample_speculative(logits, logits_draft, tokens_draft, top_k=1, top_p=0.0, temperature=1.0): + """Algorithm 1 from [1] + [1] Fast Inference from Transformers via Speculative Decoding + Yaniv Leviathan, Matan Kalman, Yossi Matias + https://arxiv.org/abs/2211.17192 + + Arguments: + logits: Tensor of shape (batch_size, seqlen + 1, vocab_size) + logits_draft: Tensor of shape (batch_size, seqlen, vocab_size) + tokens_draft: Tensor of shape (batch_size, seqlen) + Return: + tokens: Tensor of shape (batch_size, seqlen + 1) + num_generated_tokens: Tensor of shape (batch_size), with value in [1, seqlen + 1]. + For each sequence in the batch, the number of valid tokens that were sampled by + speculative sampling. + """ + batch, seqlen_p_1, vocab_size = logits.shape + seqlen = seqlen_p_1 - 1 + assert logits_draft.shape == (batch, seqlen, vocab_size) + assert tokens_draft.shape == (batch, seqlen) + assert tokens_draft.dtype in [torch.int64, torch.int32] + # TODO: if top_k = 1 we can simplify things and only work with indices + if top_p > 0.0: + assert top_p <= 1.0, "top-p should be in (0, 1]." + # Clone so that when we modify for top_p we don't change the original logits + logits = logits / temperature if temperature != 1.0 else logits.clone() + logits_draft = logits_draft / temperature if temperature != 1.0 else logits_draft.clone() + if top_k > 0: + top_k = min(top_k, logits.size(-1)) # Safety check + modify_logits_for_top_k_filtering(logits, top_k) + modify_logits_for_top_k_filtering(logits_draft, top_k) + modify_logits_for_top_p_filtering(logits, top_p) + modify_logits_for_top_p_filtering(logits_draft, top_p) + probs = torch.softmax(logits, dim=-1) + probs_draft = torch.softmax(logits_draft, dim=-1) + gather = lambda probs, tokens: rearrange( + probs.gather(dim=-1, index=rearrange(tokens, "... -> ... 1")), "... 1 -> ..." 
+ ) + # (batch, seqlen) + accepted = torch.rand(batch, seqlen, device=probs.device) * gather( + probs_draft, tokens_draft + ) <= gather(probs[:, :-1], tokens_draft) + accepted_all = accepted.all(dim=-1) + # (batch,) + first_rejected_idx = torch.where(accepted_all, seqlen, accepted.int().argmin(dim=-1)) + probs_diff = torch.clamp(probs[:, :-1] - probs_draft, min=0.0) + # torch.multinomial can deal with unnormalized probabilities + # probs_diff /= probs_diff.sum(dim=-1, keepdim=True) + resample_probs = torch.cat([probs_diff, probs[:, -1:]], dim=1) + resample_probs = rearrange( + resample_probs.gather(dim=1, index=repeat(first_rejected_idx, "b -> b 1 d", d=vocab_size)), + "b 1 d -> b d", + ) + resample = torch.multinomial(resample_probs, num_samples=1).squeeze(dim=-1) # (batch,) + tokens = F.pad(tokens_draft, (0, 1)) + tokens[:, first_rejected_idx] = resample + return tokens, first_rejected_idx + 1 + + +@torch.inference_mode() +def decode_speculative( + input_ids, + model, + model_draft, + max_length, + speculative_lookahead=3, + top_k=1, + top_p=0.0, + temperature=1.0, + eos_token_id=None, + vocab_size=None, + tensor_parallel=1, + cg=False, + enable_timing=False, + debug=False, +): + """ + TD: WIP, for my own understanding, lightly tested. Only support batch_size == 1 for now. + + Speculative decoding, either greedy or with top-k or top-p sampling. + If top-k = 0, don't limit the number of candidates (pure sampling). + Top-k and top-p can be used together. If top_k > 0 and top_p > 0, then top-k is applied first, + then top-p. + We assume that all sequences in the same batch have the same length. + + Arguments: + input_ids: (batch, seq_len) + max_length: int + Returns: GreedySearchDecoderOnlyOutput or SampleDecoderOnlyOutput, with the following fields: + sequences: (batch, max_length) + scores: tuples of (batch, vocab_size) + """ + batch_size, seqlen_og = input_ids.shape + assert batch_size == 1, "Speculative decoding implementation only supports batch_size=1" + assert eos_token_id is None, "Speculative decoding implementation doesn't support eos_token_id" + if cg: + if not hasattr(model_draft, "_decoding_cache"): + model_draft._decoding_cache = None + model_draft._decoding_cache = update_graph_cache( + model_draft, + model_draft._decoding_cache, + batch_size, + seqlen_og, + max_length, + # draft model needs to process either 1 or 2 tokens at a time + decoding_seqlens=(1, 2), + tensor_parallel=tensor_parallel, + ) + inference_params_draft = model_draft._decoding_cache.inference_params + inference_params_draft.reset(max_length, batch_size) + if not hasattr(model, "_decoding_cache"): + model._decoding_cache = None + model._decoding_cache = update_graph_cache( + model, + model._decoding_cache, + batch_size, + seqlen_og, + max_length, + decoding_seqlens=range(1, speculative_lookahead + 2), + tensor_parallel=tensor_parallel, + ) + inference_params = model._decoding_cache.inference_params + inference_params.reset(max_length, batch_size) + else: + inference_params_draft = InferenceParams(max_seqlen=max_length, max_batch_size=batch_size) + inference_params = InferenceParams(max_seqlen=max_length, max_batch_size=batch_size) + + def get_logits(input_ids, inference_params, model, num_last_tokens=1, cg=False): + decoding = inference_params.seqlen_offset > 0 + if decoding: + seqlen = input_ids.shape[1] + # if inference_params.lengths_per_sample is None: + # TODO: in the case of batched decoding where each sequence has a different length, + # we need to compute the position_ids for each sequence using 
lengths_per_sample + if True: + cache_seqlens = torch.full( + (input_ids.shape[0],), + inference_params.seqlen_offset, + dtype=torch.int32, + device=input_ids.device, + ) + else: + cache_seqlens = inference_params.lengths_per_sample + position_ids = cache_seqlens[:, None] + torch.arange( + seqlen, dtype=torch.long, device=input_ids.device + ) + else: + position_ids = None + if not cg or not decoding: + logits = model( + input_ids, + position_ids=position_ids, + inference_params=inference_params, + num_last_tokens=num_last_tokens, + ).logits + else: + # NOTE: careful, CUDA graph is set to have num_last_tokens=input_ids.shape[1]. + # This might not be compatible the num_last_tokens used here. + assert num_last_tokens <= input_ids.shape[1] + logits = model._decoding_cache.run( + input_ids, position_ids, inference_params.seqlen_offset + )[:, -num_last_tokens:] + return logits[..., :vocab_size] if vocab_size is not None else logits + + def sample_tokens(input_ids, get_logits_fn, inference_params, sample_fn, num_tokens=1): + """Sample `num_tokens` tokens from the model, given the previous logits. + Also return the logits of the sampled tokens. + Arguments: + input_ids: (batch, seqlen) + Return: + tokens: (batch, num_tokens) + scores: (batch, num_tokens), which contains @previous_logits and the logits of the next + (num_tokens - 1) tokens. The logits of the last token isn't computed. + """ + assert num_tokens >= 1 + sequences, scores = [input_ids], [] + for i in range(num_tokens): + scores.append(get_logits_fn(sequences[-1], inference_params)[:, -1]) + inference_params.seqlen_offset += sequences[-1].shape[1] + sequences.append(sample_fn(scores[-1]).unsqueeze(1)) + return torch.cat(sequences[1:], dim=1), torch.stack(scores, dim=1) + + sampling_kwargs = dict(top_k=top_k, top_p=top_p, temperature=temperature) + sample_fn = partial(sample, **sampling_kwargs) + get_logits_main = partial(get_logits, model=model, cg=cg) + get_logits_draft = partial(get_logits, model=model_draft, cg=cg) + sample_tokens_main = partial( + sample_tokens, + get_logits_fn=get_logits_main, + sample_fn=sample_fn, + inference_params=inference_params, + ) + sample_tokens_draft = partial( + sample_tokens, + get_logits_fn=get_logits_draft, + sample_fn=sample_fn, + inference_params=inference_params_draft, + ) + + if debug: + from transformers import AutoTokenizer + + tokenizer = AutoTokenizer.from_pretrained("gpt2") + if enable_timing: + if tensor_parallel > 1: + torch.distributed.barrier() + torch.cuda.synchronize() + start = time.time() + + sequences, scores = [input_ids], [] + num_main_model_calls = 0 + num_draft_tokens = 0 + num_accepted_tokens_history = [] + if seqlen_og >= max_length - 1: + # Don't do speculative sampling, just sample 1 token from the model + tokens, scores_new = sample_tokens_main(input_ids, num_tokens=1) + sequences.append(tokens) + scores.append(scores_new) + else: + # Sample from draft model, which produces @n_spec_tokens, and @model + # will then use to produce between 1 and 1 + @n_spec_tokens tokens. + # We want seqlen_og + 1 + @n_spec_tokens to be <= @max_length. 
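
As an aside, the sampling helpers defined earlier in this file (sample, plus the top-k/top-p filters it calls) can be exercised standalone. A minimal sketch with random CPU logits; the batch and vocab sizes are arbitrary:

    import torch
    from xformers._flash_attn.utils.generation import sample

    torch.manual_seed(0)
    logits = torch.randn(2, 32000)  # (batch_size, vocab_size)

    greedy = sample(logits, top_k=1)                      # plain argmax, shape (2,)
    topk_tok = sample(logits, top_k=50, temperature=0.8)  # top-k filter, then sample
    nucleus = sample(logits, top_k=0, top_p=0.9)          # pure top-p (nucleus) sampling

As the docstring of decode above notes, when both thresholds are positive, top-k is applied first and top-p second.
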
+ n_spec_tokens = min(speculative_lookahead, max_length - seqlen_og - 1) + tokens_draft, scores_draft = sample_tokens_draft(input_ids, num_tokens=n_spec_tokens) + num_draft_tokens += n_spec_tokens + if debug: + scores_draft_ref = model_draft( + torch.cat([input_ids, tokens_draft], dim=1), num_last_tokens=n_spec_tokens + 1 + ).logits + print((scores_draft - scores_draft_ref[:, :-1]).abs().max()) + + # Evaluate the draft tokens with the model + logits = get_logits_main( + torch.cat([input_ids, tokens_draft], dim=1), + inference_params, + num_last_tokens=n_spec_tokens + 1, + ) + num_main_model_calls += 1 + if debug: + logits_ref = model( + torch.cat([input_ids, tokens_draft], dim=1), num_last_tokens=n_spec_tokens + 1 + ).logits + print((logits - logits_ref).abs().max()) + # breakpoint() + tokens, num_generated_tokens = sample_speculative( + logits, scores_draft, tokens_draft, **sampling_kwargs + ) + num_accepted_tokens_history.append(num_generated_tokens - 1) + if debug: + print(tokens) + print(num_generated_tokens) + # breakpoint() + # TODO: we're using the fact that batch_size == 1 + # TODO: check eos_token_id + sequences.append(tokens[:1, : num_generated_tokens[0]]) + scores.append(logits[:1, : num_generated_tokens[0]]) + # Note that @model has not evaluated the last sampled token yet, so we'll need to pass + # that in the next time we call @model. + num_generated = num_generated_tokens[0].item() + inference_params.seqlen_offset = seqlen_og + num_generated - 1 + inference_params_draft.seqlen_offset = ( + inference_params.seqlen_offset - 1 + if num_generated > 1 + else inference_params.seqlen_offset + ) + if debug: + cur_ids = torch.cat([input_ids, sequences[-1]], dim=1) + scores_ref = model(cur_ids, num_last_tokens=num_generated_tokens[0].item() + 1).logits + print((scores[-1] - scores_ref[:, :-1]).abs().max()) + # breakpoint() + + while True: + # seqlen_offset is total length generated - 1 + if inference_params.seqlen_offset >= max_length - 1: + break + if inference_params.seqlen_offset >= max_length - 2: + # Don't do speculative sampling, just sample 1 token from the model + tokens, scores_new = sample_tokens_main(sequences[-1][:, -1:], num_tokens=1) + sequences.append(tokens) + scores.append(scores_new) + break + # Sample from draft model + n_spec_tokens = min( + speculative_lookahead, max_length - inference_params_draft.seqlen_offset - 2 + ) + # If the main model accepts all the draft tokens, plus it samples one new token, + # then at the next iteration the draft model need to evaluate the logits of the last draft + # token and the logits of the newly sampled token. So here we pass in the last 2 tokens + # of sequences[-1]. + # This exception is when the main model rejects all the draft tokens, in which case we + # will only have 1 token to pass in. 
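
The contract of sample_speculative above is easiest to see with concrete shapes. A sketch with random logits and batch size 1 (the only batch size decode_speculative supports); the vocab size and lookahead length here are arbitrary:

    import torch
    from xformers._flash_attn.utils.generation import sample_speculative

    batch, seqlen, vocab = 1, 3, 100
    logits = torch.randn(batch, seqlen + 1, vocab)    # main model: draft positions + 1
    logits_draft = torch.randn(batch, seqlen, vocab)  # draft model
    tokens_draft = torch.randint(0, vocab, (batch, seqlen))

    tokens, num_generated = sample_speculative(logits, logits_draft, tokens_draft, top_k=5)
    assert tokens.shape == (batch, seqlen + 1)
    # Anywhere from 1 (first draft token rejected) to seqlen + 1 (all accepted) is valid.
    assert 1 <= num_generated.item() <= seqlen + 1
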
+ tokens_draft, scores_draft = sample_tokens_draft( + sequences[-1][:, -2:], num_tokens=n_spec_tokens + ) + num_draft_tokens += n_spec_tokens + if debug: + scores_draft_ref = model_draft( + torch.cat([cur_ids, tokens_draft], dim=1), num_last_tokens=n_spec_tokens + 1 + ).logits + print((scores_draft - scores_draft_ref[:, :-1]).abs().max()) + # breakpoint() + # Evaluate the draft tokens with the model + logits = get_logits_main( + torch.cat([sequences[-1][:, -1:], tokens_draft], dim=1), + inference_params, + num_last_tokens=n_spec_tokens + 1, + ) # (batch, n_spec_tokens + 1, vocab_size) + num_main_model_calls += 1 + if debug: + logits_ref = model( + torch.cat([cur_ids, tokens_draft], dim=1), num_last_tokens=n_spec_tokens + 1 + ).logits + print((logits - logits_ref).abs().max()) + # breakpoint() + tokens, num_generated_tokens = sample_speculative( + logits, scores_draft, tokens_draft, **sampling_kwargs + ) + num_accepted_tokens_history.append(num_generated_tokens - 1) + if debug: + print(tokens) + print(num_generated_tokens) + # breakpoint() + sequences.append(tokens[:1, : num_generated_tokens[0]]) + scores.append(logits[:1, : num_generated_tokens[0]]) + # We've evaluated 1 token from sequences[-1][:, -1:] above, plus + # num_generated_tokens[0].item() - 1 tokens from the draft model. + num_generated = num_generated_tokens[0].item() + inference_params.seqlen_offset += num_generated + inference_params_draft.seqlen_offset = ( + inference_params.seqlen_offset - 1 + if num_generated > 1 + else inference_params.seqlen_offset + ) + if debug: + cur_ids = torch.cat([cur_ids, sequences[-1]], dim=1) + scores_ref = model(cur_ids, num_last_tokens=num_generated_tokens[0].item() + 1).logits + print((scores[-1] - scores_ref[:, :-1]).abs().max()) + # breakpoint() + + if enable_timing: + if tensor_parallel > 1: + torch.distributed.barrier() + torch.cuda.synchronize() + print(f"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms") + print(f"Number of calls to main model: {num_main_model_calls}") + print( + f"Acceptance rate: {torch.cat(num_accepted_tokens_history).sum().item() / num_draft_tokens * 100:.2f}%" + ) + sequences = torch.cat(sequences, dim=1) + scores = torch.cat(scores, dim=1) + if debug: + scores_ref = model(sequences).logits + print((scores - scores_ref[:, seqlen_og - 1 : -1]).abs().max()) + output_cls = GreedySearchDecoderOnlyOutput if top_k == 1 else SampleDecoderOnlyOutput + return output_cls(sequences=sequences, scores=scores) + + +class GenerationMixin: + def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs): + raise NotImplementedError + + def generate( + self, + input_ids, + max_length, + top_k=1, + top_p=0.0, + temperature=1.0, + return_dict_in_generate=False, + output_scores=False, + **kwargs, + ): + output = decode( + input_ids, self, max_length, top_k=top_k, top_p=top_p, temperature=temperature, **kwargs + ) + if not output_scores: + output.scores = None + return output if return_dict_in_generate else output.sequences + + +def allocate_inference_cache( + max_batch_size, + max_seqlen, + nheads, + headdim, + layers: Union[int, Sequence], + device, + dtype=torch.float16, +): + assert dtype in [torch.float16, torch.bfloat16, torch.float32] + kv_cache_shape = (max_batch_size, max_seqlen, 2, nheads, headdim) + if isinstance(layers, int): + layers = range(layers) + return {i: torch.empty(kv_cache_shape, device=device, dtype=dtype) for i in layers} + + +@dataclass +class DecodingCGCache: + max_batch_size: int = 0 + max_seqlen: int = 0 + device = 
None + dtype = None + callables: dict = field(default_factory=dict) + mempool = None + inference_params: Optional[InferenceParams] = None + run: Optional[Callable] = None + + +@torch.inference_mode() +def update_graph_cache( + model, + cache, + batch_size, + seqlen_og, + max_seqlen, + decoding_seqlens=(1,), + tensor_parallel=1, + dtype=None, + n_warmups=2, +): + if cache is None: + cache = DecodingCGCache() + param_example = next(iter(model.parameters())) + device = param_example.device + if dtype is None: + dtype = param_example.dtype + if ( + (device, dtype) != (cache.device, cache.dtype) + or batch_size > cache.max_batch_size + or max_seqlen > cache.max_seqlen + ): # Invalidate the cache + cache.callables = {} + cache.mempool = None + cache.inference_params = None + gc.collect() + cache.device, cache.dtype = device, dtype + cache.max_batch_size, cache.max_seqlen = batch_size, max_seqlen + if hasattr(model, "allocate_inference_cache"): + inf_cache = model.allocate_inference_cache(batch_size, max_seqlen, dtype) + else: + headdim = getattr( + model.config, + "head_dim", + model.config.hidden_size // model.config.num_attention_heads, + ) + inf_cache = allocate_inference_cache( + batch_size, + max_seqlen, + model.config.num_attention_heads // tensor_parallel, + headdim, + model.config.num_hidden_layers, + device, + dtype, + ) + lengths_per_sample = torch.full((batch_size,), seqlen_og, dtype=torch.int32, device=device) + cache.inference_params = InferenceParams( + max_seqlen=max_seqlen, + max_batch_size=batch_size, + seqlen_offset=seqlen_og, + key_value_memory_dict=inf_cache, + lengths_per_sample=lengths_per_sample, + ) + cache.mempool = torch.cuda.graphs.graph_pool_handle() + for decoding_seqlen in decoding_seqlens: + if (batch_size, decoding_seqlen) not in cache.callables: + cache.callables[batch_size, decoding_seqlen] = capture_graph( + model, + cache.inference_params, + batch_size, + max_seqlen, + decoding_seqlen=decoding_seqlen, + mempool=cache.mempool, + n_warmups=n_warmups, + ) + + def dispatch(input_ids, position_ids, seqlen): + batch_size, decoding_seqlen = input_ids.shape[:2] + return cache.callables[batch_size, decoding_seqlen](input_ids, position_ids, seqlen) + + cache.run = dispatch + cache.inference_params.seqlen_offset = 0 # Reset so it's not confusing + return cache + + +def capture_graph( + model, inference_params, batch_size, max_seqlen, decoding_seqlen=1, mempool=None, n_warmups=2 +): + device = next(iter(model.parameters())).device + input_ids = torch.full((batch_size, decoding_seqlen), 0, dtype=torch.long, device=device) + position_ids = torch.full((batch_size, decoding_seqlen), 0, dtype=torch.long, device=device) + seqlen_offset_og = inference_params.seqlen_offset + inference_params.seqlen_offset = max_seqlen - decoding_seqlen + inference_params.lengths_per_sample[:] = inference_params.seqlen_offset + + # Warmup before capture + s = torch.cuda.Stream() + s.wait_stream(torch.cuda.current_stream()) + with torch.cuda.stream(s): + for _ in range(n_warmups): + logits = model( + input_ids, + position_ids=position_ids, + inference_params=inference_params, + num_last_tokens=decoding_seqlen, + ).logits + s.synchronize() + # This might be needed for correctness if we run with NCCL_GRAPH_MIXING_SUPPORT=0, + # which requires that graph launch and non-captured launch to not overlap (I think, + # that's how I interpret the documentation). I'm not sure if this is required. 
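
The capture happening in capture_graph here follows the standard CUDA Graphs recipe: warm up on a side stream, record the kernels into static buffers, then replay after copying fresh inputs into those buffers in place. The same pattern in isolation, as a minimal sketch with a toy matmul standing in for the model:

    import torch

    static_x = torch.randn(8, 64, device="cuda")
    weight = torch.randn(64, 64, device="cuda")

    # Warm up on a side stream so capture sees initialized state.
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        for _ in range(2):
            static_y = static_x @ weight
    torch.cuda.current_stream().wait_stream(s)

    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        static_y = static_x @ weight  # recorded into the graph, not executed eagerly

    def run(new_x):
        static_x.copy_(new_x)  # refill the captured input buffer in place
        graph.replay()         # re-launch the recorded kernels
        return static_y.clone()

    out = run(torch.randn(8, 64, device="cuda"))
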
+ if torch.distributed.is_initialized(): + torch.distributed.barrier() + torch.cuda.current_stream().wait_stream(s) + # Captures the graph + # To allow capture, automatically sets a side stream as the current stream in the context + graph = torch.cuda.CUDAGraph() + with torch.cuda.graph(graph, pool=mempool): + logits = model( + input_ids, + position_ids=position_ids, + inference_params=inference_params, + num_last_tokens=decoding_seqlen, + ).logits + + def run(new_input_ids, new_position_ids, seqlen): + inference_params.lengths_per_sample[:] = seqlen + input_ids.copy_(new_input_ids) + position_ids.copy_(new_position_ids) + graph.replay() + return logits.clone() + + inference_params.seqlen_offset = seqlen_offset_og + return run diff --git a/.venv/lib/python3.11/site-packages/xformers/_flash_attn/utils/pretrained.py b/.venv/lib/python3.11/site-packages/xformers/_flash_attn/utils/pretrained.py new file mode 100644 index 0000000000000000000000000000000000000000..40e76bd2692335c7f474f6b6479be67eb95f8d20 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/xformers/_flash_attn/utils/pretrained.py @@ -0,0 +1,79 @@ +import os +from functools import partial + +import torch +from safetensors.torch import load_file as safe_load_file +from transformers.utils import ( + SAFE_WEIGHTS_INDEX_NAME, + SAFE_WEIGHTS_NAME, + WEIGHTS_INDEX_NAME, + WEIGHTS_NAME, +) +from transformers.utils.hub import cached_file, get_checkpoint_shard_files + + +def state_dict_from_pretrained(model_name, device=None, dtype=None): + # If not fp32, then we don't want to load directly to the GPU + mapped_device = "cpu" if dtype not in [torch.float32, None] else device + is_sharded = False + load_safe = False + resolved_archive_file = None + + weights_path = os.path.join(model_name, WEIGHTS_NAME) + weights_index_path = os.path.join(model_name, WEIGHTS_INDEX_NAME) + safe_weights_path = os.path.join(model_name, SAFE_WEIGHTS_NAME) + safe_weights_index_path = os.path.join(model_name, SAFE_WEIGHTS_INDEX_NAME) + + if os.path.isfile(weights_path): + resolved_archive_file = cached_file( + model_name, WEIGHTS_NAME, _raise_exceptions_for_missing_entries=False + ) + elif os.path.isfile(weights_index_path): + resolved_archive_file = cached_file( + model_name, WEIGHTS_INDEX_NAME, _raise_exceptions_for_missing_entries=False + ) + is_sharded = True + elif os.path.isfile(safe_weights_path): + resolved_archive_file = cached_file( + model_name, SAFE_WEIGHTS_NAME, _raise_exceptions_for_missing_entries=False + ) + load_safe = True + elif os.path.isfile(safe_weights_index_path): + resolved_archive_file = cached_file( + model_name, SAFE_WEIGHTS_INDEX_NAME, _raise_exceptions_for_missing_entries=False + ) + is_sharded = True + load_safe = True + else: # Try loading from HF hub instead of from local files + resolved_archive_file = cached_file(model_name, WEIGHTS_NAME, + _raise_exceptions_for_missing_entries=False) + if resolved_archive_file is None: + resolved_archive_file = cached_file(model_name, WEIGHTS_INDEX_NAME, + _raise_exceptions_for_missing_entries=False) + if resolved_archive_file is not None: + is_sharded = True + + if resolved_archive_file is None: + raise EnvironmentError(f"Model name {model_name} was not found.") + + if load_safe: + loader = partial(safe_load_file, device=mapped_device) + else: + loader = partial(torch.load, map_location=mapped_device) + + if is_sharded: + # resolved_archive_file becomes a list of files that point to the different + # checkpoint shards in this case. 
+ resolved_archive_file, sharded_metadata = get_checkpoint_shard_files( + model_name, resolved_archive_file + ) + state_dict = {} + for sharded_file in resolved_archive_file: + state_dict.update(loader(sharded_file)) + else: + state_dict = loader(resolved_archive_file) + # Convert dtype before moving to GPU to save memory + if dtype is not None: + state_dict = {k: v.to(dtype=dtype) for k, v in state_dict.items()} + state_dict = {k: v.to(device=device) for k, v in state_dict.items()} + return state_dict diff --git a/llm_tutorial/llm_recipes/models/hf-model-eval/llm-jp-v3-3.7b_ja-en_3M-pairs_3.5e-5/iter_0000698/model-00004-of-00004.safetensors b/llm_tutorial/llm_recipes/models/hf-model-eval/llm-jp-v3-3.7b_ja-en_3M-pairs_3.5e-5/iter_0000698/model-00004-of-00004.safetensors new file mode 100644 index 0000000000000000000000000000000000000000..f5d3db38b3548c188817aa7912daf98b3782116d --- /dev/null +++ b/llm_tutorial/llm_recipes/models/hf-model-eval/llm-jp-v3-3.7b_ja-en_3M-pairs_3.5e-5/iter_0000698/model-00004-of-00004.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:43d65a1f8bca72079e5fd8d5d61a739fbb597854f9c1e1cac01d74b58a25b38a +size 1223688320
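
For completeness, a usage sketch for state_dict_from_pretrained from utils/pretrained.py above; "gpt2" is just an illustrative checkpoint id, and any local directory containing pytorch_model.bin or model.safetensors (sharded or not) would be resolved the same way:

    import torch
    from xformers._flash_attn.utils.pretrained import state_dict_from_pretrained

    # Non-fp32 dtypes are loaded to CPU first, cast, and only then moved to the
    # target device, which keeps peak GPU memory at roughly the final model size.
    sd = state_dict_from_pretrained("gpt2", device="cuda", dtype=torch.float16)
    print(len(sd), next(iter(sd)))
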