Add files using upload-large-folder tool
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .venv/lib/python3.11/site-packages/xformers/_flash_attn/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/xformers/_flash_attn/__pycache__/bert_padding.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/xformers/_flash_attn/__pycache__/flash_attn_interface.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/xformers/_flash_attn/__pycache__/flash_attn_triton.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/xformers/_flash_attn/__pycache__/flash_attn_triton_og.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/xformers/_flash_attn/__pycache__/flash_blocksparse_attention.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/xformers/_flash_attn/__pycache__/flash_blocksparse_attn_interface.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/xformers/_flash_attn/__pycache__/fused_softmax.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/xformers/_flash_attn/layers/__init__.py +0 -0
- .venv/lib/python3.11/site-packages/xformers/_flash_attn/layers/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/xformers/_flash_attn/layers/__pycache__/patch_embed.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/xformers/_flash_attn/layers/__pycache__/rotary.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/xformers/_flash_attn/layers/patch_embed.py +67 -0
- .venv/lib/python3.11/site-packages/xformers/_flash_attn/layers/rotary.py +481 -0
- .venv/lib/python3.11/site-packages/xformers/_flash_attn/losses/__init__.py +0 -0
- .venv/lib/python3.11/site-packages/xformers/_flash_attn/losses/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/xformers/_flash_attn/losses/__pycache__/cross_entropy.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/xformers/_flash_attn/losses/cross_entropy.py +85 -0
- .venv/lib/python3.11/site-packages/xformers/_flash_attn/models/__init__.py +0 -0
- .venv/lib/python3.11/site-packages/xformers/_flash_attn/models/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/xformers/_flash_attn/models/__pycache__/baichuan.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/xformers/_flash_attn/models/__pycache__/bert.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/xformers/_flash_attn/models/__pycache__/bigcode.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/xformers/_flash_attn/models/__pycache__/btlm.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/xformers/_flash_attn/models/__pycache__/falcon.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/xformers/_flash_attn/models/__pycache__/gpt.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/xformers/_flash_attn/models/__pycache__/gpt_neox.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/xformers/_flash_attn/models/__pycache__/gptj.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/xformers/_flash_attn/models/__pycache__/llama.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/xformers/_flash_attn/models/__pycache__/opt.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/xformers/_flash_attn/models/__pycache__/vit.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/xformers/_flash_attn/models/baichuan.py +151 -0
- .venv/lib/python3.11/site-packages/xformers/_flash_attn/models/bert.py +764 -0
- .venv/lib/python3.11/site-packages/xformers/_flash_attn/models/bigcode.py +233 -0
- .venv/lib/python3.11/site-packages/xformers/_flash_attn/models/btlm.py +102 -0
- .venv/lib/python3.11/site-packages/xformers/_flash_attn/models/falcon.py +143 -0
- .venv/lib/python3.11/site-packages/xformers/_flash_attn/models/gpt.py +1080 -0
- .venv/lib/python3.11/site-packages/xformers/_flash_attn/models/gpt_neox.py +124 -0
- .venv/lib/python3.11/site-packages/xformers/_flash_attn/models/gptj.py +109 -0
- .venv/lib/python3.11/site-packages/xformers/_flash_attn/models/llama.py +422 -0
- .venv/lib/python3.11/site-packages/xformers/_flash_attn/models/opt.py +116 -0
- .venv/lib/python3.11/site-packages/xformers/_flash_attn/models/vit.py +373 -0
- .venv/lib/python3.11/site-packages/xformers/_flash_attn/ops/__init__.py +0 -0
- .venv/lib/python3.11/site-packages/xformers/_flash_attn/ops/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/xformers/_flash_attn/ops/__pycache__/activations.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/xformers/_flash_attn/ops/__pycache__/fused_dense.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/xformers/_flash_attn/ops/__pycache__/layer_norm.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/xformers/_flash_attn/ops/__pycache__/rms_norm.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/xformers/_flash_attn/ops/activations.py +135 -0
- .venv/lib/python3.11/site-packages/xformers/_flash_attn/ops/fused_dense.py +688 -0
.venv/lib/python3.11/site-packages/xformers/_flash_attn/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (606 Bytes). View file
|
|
|
.venv/lib/python3.11/site-packages/xformers/_flash_attn/__pycache__/bert_padding.cpython-311.pyc
ADDED
|
Binary file (11 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/xformers/_flash_attn/__pycache__/flash_attn_interface.cpython-311.pyc
ADDED
|
Binary file (46.3 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/xformers/_flash_attn/__pycache__/flash_attn_triton.cpython-311.pyc
ADDED
|
Binary file (44.5 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/xformers/_flash_attn/__pycache__/flash_attn_triton_og.cpython-311.pyc
ADDED
|
Binary file (16.2 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/xformers/_flash_attn/__pycache__/flash_blocksparse_attention.cpython-311.pyc
ADDED
|
Binary file (8.15 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/xformers/_flash_attn/__pycache__/flash_blocksparse_attn_interface.cpython-311.pyc
ADDED
|
Binary file (7.53 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/xformers/_flash_attn/__pycache__/fused_softmax.cpython-311.pyc
ADDED
|
Binary file (9.45 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/xformers/_flash_attn/layers/__init__.py
ADDED
|
File without changes
|
.venv/lib/python3.11/site-packages/xformers/_flash_attn/layers/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (200 Bytes). View file
|
|
|
.venv/lib/python3.11/site-packages/xformers/_flash_attn/layers/__pycache__/patch_embed.cpython-311.pyc
ADDED
|
Binary file (3.34 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/xformers/_flash_attn/layers/__pycache__/rotary.cpython-311.pyc
ADDED
|
Binary file (20.1 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/xformers/_flash_attn/layers/patch_embed.py
ADDED
|
@@ -0,0 +1,67 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# We use the same API as https://github.com/rwightman/pytorch-image-models/blob/v0.6.11/timm/models/layers/patch_embed.py
|
| 2 |
+
# But we use nn.Linear instead of Conv2d and it's about 8x faster.
|
| 3 |
+
|
| 4 |
+
from functools import partial
|
| 5 |
+
|
| 6 |
+
import torch.nn as nn
|
| 7 |
+
from einops import rearrange
|
| 8 |
+
from torch import _assert
|
| 9 |
+
from torch.nn.modules.utils import _pair
|
| 10 |
+
|
| 11 |
+
try:
|
| 12 |
+
from flash_attn.ops.fused_dense import FusedDense
|
| 13 |
+
except ImportError:
|
| 14 |
+
FusedDense = None
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class PatchEmbed(nn.Module):
|
| 18 |
+
"""2D Image to Patch Embedding"""
|
| 19 |
+
|
| 20 |
+
def __init__(
|
| 21 |
+
self,
|
| 22 |
+
img_size=224,
|
| 23 |
+
patch_size=16,
|
| 24 |
+
in_chans=3,
|
| 25 |
+
embed_dim=768,
|
| 26 |
+
norm_layer=None,
|
| 27 |
+
flatten=True,
|
| 28 |
+
bias=True,
|
| 29 |
+
fused_bias_fc=False,
|
| 30 |
+
):
|
| 31 |
+
super().__init__()
|
| 32 |
+
img_size = _pair(img_size)
|
| 33 |
+
patch_size = _pair(patch_size)
|
| 34 |
+
self.img_size = img_size
|
| 35 |
+
self.patch_size = patch_size
|
| 36 |
+
self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
|
| 37 |
+
self.num_patches = self.grid_size[0] * self.grid_size[1]
|
| 38 |
+
self.flatten = flatten
|
| 39 |
+
if fused_bias_fc and FusedDense is None:
|
| 40 |
+
raise ImportError("fused_dense is not installed")
|
| 41 |
+
|
| 42 |
+
linear_cls = nn.Linear if not fused_bias_fc or not bias else FusedDense
|
| 43 |
+
self.proj = linear_cls(in_chans * patch_size[0] * patch_size[1], embed_dim, bias=bias)
|
| 44 |
+
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
|
| 45 |
+
|
| 46 |
+
def forward(self, x):
|
| 47 |
+
_, _, H, W = x.shape
|
| 48 |
+
_assert(
|
| 49 |
+
H == self.img_size[0],
|
| 50 |
+
f"Input image height ({H}) doesn't match model ({self.img_size[0]}).",
|
| 51 |
+
)
|
| 52 |
+
_assert(
|
| 53 |
+
W == self.img_size[1],
|
| 54 |
+
f"Input image width ({W}) doesn't match model ({self.img_size[1]}).",
|
| 55 |
+
)
|
| 56 |
+
x = self.proj(
|
| 57 |
+
rearrange(
|
| 58 |
+
x,
|
| 59 |
+
"b c (h p1) (w p2) -> b h w (c p1 p2)",
|
| 60 |
+
p1=self.patch_size[0],
|
| 61 |
+
p2=self.patch_size[1],
|
| 62 |
+
)
|
| 63 |
+
)
|
| 64 |
+
if self.flatten:
|
| 65 |
+
x = rearrange(x, "b h w c -> b (h w) c")
|
| 66 |
+
x = self.norm(x)
|
| 67 |
+
return x
|
.venv/lib/python3.11/site-packages/xformers/_flash_attn/layers/rotary.py
ADDED
|
@@ -0,0 +1,481 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2023, Tri Dao.
|
| 2 |
+
|
| 3 |
+
import math
|
| 4 |
+
from typing import Optional, Tuple, Union
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
from einops import rearrange, repeat
|
| 8 |
+
from flash_attn.ops.triton.rotary import apply_rotary
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def rotate_half(x, interleaved=False):
|
| 12 |
+
if not interleaved:
|
| 13 |
+
x1, x2 = x.chunk(2, dim=-1)
|
| 14 |
+
return torch.cat((-x2, x1), dim=-1)
|
| 15 |
+
else:
|
| 16 |
+
x1, x2 = x[..., ::2], x[..., 1::2]
|
| 17 |
+
return rearrange(torch.stack((-x2, x1), dim=-1), "... d two -> ... (d two)", two=2)
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def apply_rotary_emb_torch(x, cos, sin, interleaved=False):
|
| 21 |
+
"""
|
| 22 |
+
x: (batch_size, seqlen, nheads, headdim)
|
| 23 |
+
cos, sin: (seqlen, rotary_dim / 2) or (batch_size, seqlen, rotary_dim / 2)
|
| 24 |
+
"""
|
| 25 |
+
ro_dim = cos.shape[-1] * 2
|
| 26 |
+
assert ro_dim <= x.shape[-1]
|
| 27 |
+
cos = repeat(cos, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)")
|
| 28 |
+
sin = repeat(sin, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)")
|
| 29 |
+
return torch.cat(
|
| 30 |
+
[x[..., :ro_dim] * cos + rotate_half(x[..., :ro_dim], interleaved) * sin, x[..., ro_dim:]],
|
| 31 |
+
dim=-1,
|
| 32 |
+
)
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
class ApplyRotaryEmb(torch.autograd.Function):
|
| 36 |
+
@staticmethod
|
| 37 |
+
def forward(
|
| 38 |
+
ctx,
|
| 39 |
+
x,
|
| 40 |
+
cos,
|
| 41 |
+
sin,
|
| 42 |
+
interleaved=False,
|
| 43 |
+
inplace=False,
|
| 44 |
+
seqlen_offsets: Union[int, torch.Tensor] = 0,
|
| 45 |
+
cu_seqlens: Optional[torch.Tensor] = None,
|
| 46 |
+
max_seqlen: Optional[int] = None,
|
| 47 |
+
):
|
| 48 |
+
out = apply_rotary(
|
| 49 |
+
x,
|
| 50 |
+
cos,
|
| 51 |
+
sin,
|
| 52 |
+
seqlen_offsets=seqlen_offsets,
|
| 53 |
+
cu_seqlens=cu_seqlens,
|
| 54 |
+
max_seqlen=max_seqlen,
|
| 55 |
+
interleaved=interleaved,
|
| 56 |
+
inplace=inplace,
|
| 57 |
+
)
|
| 58 |
+
if isinstance(seqlen_offsets, int):
|
| 59 |
+
ctx.save_for_backward(cos, sin, cu_seqlens) # Can't save int with save_for_backward
|
| 60 |
+
ctx.seqlen_offsets = seqlen_offsets
|
| 61 |
+
else:
|
| 62 |
+
ctx.save_for_backward(cos, sin, cu_seqlens, seqlen_offsets)
|
| 63 |
+
ctx.seqlen_offsets = None
|
| 64 |
+
ctx.interleaved = interleaved
|
| 65 |
+
ctx.inplace = inplace
|
| 66 |
+
ctx.max_seqlen = max_seqlen
|
| 67 |
+
return out if not inplace else x
|
| 68 |
+
|
| 69 |
+
@staticmethod
|
| 70 |
+
def backward(ctx, do):
|
| 71 |
+
seqlen_offsets = ctx.seqlen_offsets
|
| 72 |
+
if seqlen_offsets is None:
|
| 73 |
+
cos, sin, cu_seqlens, seqlen_offsets = ctx.saved_tensors
|
| 74 |
+
else:
|
| 75 |
+
cos, sin, cu_seqlens = ctx.saved_tensors
|
| 76 |
+
# TD [2023-09-02]: For some reason Triton (2.0.0.post1) errors with
|
| 77 |
+
# "[CUDA]: invalid device context", and cloning makes it work. Idk why. Triton 2.1.0 works.
|
| 78 |
+
if not ctx.interleaved and not ctx.inplace:
|
| 79 |
+
do = do.clone()
|
| 80 |
+
dx = apply_rotary(
|
| 81 |
+
do,
|
| 82 |
+
cos,
|
| 83 |
+
sin,
|
| 84 |
+
seqlen_offsets=seqlen_offsets,
|
| 85 |
+
cu_seqlens=cu_seqlens,
|
| 86 |
+
max_seqlen=ctx.max_seqlen,
|
| 87 |
+
interleaved=ctx.interleaved,
|
| 88 |
+
inplace=ctx.inplace,
|
| 89 |
+
conjugate=True,
|
| 90 |
+
)
|
| 91 |
+
return dx, None, None, None, None, None, None, None
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
def apply_rotary_emb(
|
| 95 |
+
x,
|
| 96 |
+
cos,
|
| 97 |
+
sin,
|
| 98 |
+
interleaved=False,
|
| 99 |
+
inplace=False,
|
| 100 |
+
seqlen_offsets: Union[int, torch.Tensor] = 0,
|
| 101 |
+
cu_seqlens: Optional[torch.Tensor] = None,
|
| 102 |
+
max_seqlen: Optional[int] = None,
|
| 103 |
+
):
|
| 104 |
+
"""
|
| 105 |
+
Arguments:
|
| 106 |
+
x: (batch_size, seqlen, nheads, headdim) if cu_seqlens is None
|
| 107 |
+
else (total_seqlen, nheads, headdim)
|
| 108 |
+
cos, sin: (seqlen_rotary, rotary_dim / 2)
|
| 109 |
+
interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead
|
| 110 |
+
of 1st half and 2nd half (GPT-NeoX style).
|
| 111 |
+
inplace: if True, apply rotary embedding in-place.
|
| 112 |
+
seqlen_offsets: (batch_size,) or int. Each sequence in x is shifted by this amount.
|
| 113 |
+
Most commonly used in inference when we have KV cache.
|
| 114 |
+
cu_seqlens: (batch + 1,) or None
|
| 115 |
+
max_seqlen: int
|
| 116 |
+
Return:
|
| 117 |
+
out: (batch_size, seqlen, nheads, headdim) if cu_seqlens is None
|
| 118 |
+
else (total_seqlen, nheads, headdim)
|
| 119 |
+
rotary_dim must be <= headdim
|
| 120 |
+
Apply rotary embedding to the first rotary_dim of x.
|
| 121 |
+
"""
|
| 122 |
+
return ApplyRotaryEmb.apply(
|
| 123 |
+
x, cos, sin, interleaved, inplace, seqlen_offsets, cu_seqlens, max_seqlen
|
| 124 |
+
)
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
# For backward compatibility
|
| 128 |
+
apply_rotary_emb_func = apply_rotary_emb
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
class ApplyRotaryEmbQKV_(torch.autograd.Function):
|
| 132 |
+
@staticmethod
|
| 133 |
+
def forward(
|
| 134 |
+
ctx,
|
| 135 |
+
qkv,
|
| 136 |
+
cos,
|
| 137 |
+
sin,
|
| 138 |
+
cos_k=None,
|
| 139 |
+
sin_k=None,
|
| 140 |
+
interleaved=False,
|
| 141 |
+
seqlen_offsets: Union[int, torch.Tensor] = 0,
|
| 142 |
+
):
|
| 143 |
+
batch, seqlen, three, nheads, headdim = qkv.shape
|
| 144 |
+
assert three == 3
|
| 145 |
+
if cos_k is None and sin_k is None and qkv.is_contiguous():
|
| 146 |
+
# Call 1 kernel instead of 2 kernels
|
| 147 |
+
# We need qkv to be contiguous so that when we reshape to combine (3, nheads)
|
| 148 |
+
# dimensions, we get the same tensor
|
| 149 |
+
# qk = rearrange(qkv[:, :, :2], "b s t h d -> b s (t h) d")
|
| 150 |
+
qk = qkv[:, :, :2].reshape(batch, seqlen, -1, headdim)
|
| 151 |
+
apply_rotary(
|
| 152 |
+
qk, cos, sin, seqlen_offsets=seqlen_offsets, interleaved=interleaved, inplace=True
|
| 153 |
+
)
|
| 154 |
+
else:
|
| 155 |
+
cos_k = cos if cos_k is None else cos_k
|
| 156 |
+
sin_k = sin if sin_k is None else sin_k
|
| 157 |
+
q, k = qkv[:, :, 0], qkv[:, :, 1]
|
| 158 |
+
apply_rotary(q, cos, sin, seqlen_offsets, interleaved=interleaved, inplace=True)
|
| 159 |
+
apply_rotary(k, cos_k, sin_k, seqlen_offsets, interleaved=interleaved, inplace=True)
|
| 160 |
+
ctx.save_for_backward(cos, sin, cos_k, sin_k)
|
| 161 |
+
if isinstance(seqlen_offsets, int):
|
| 162 |
+
ctx.save_for_backward(cos, sin, cos_k, sin_k)
|
| 163 |
+
ctx.seqlen_offsets = seqlen_offsets
|
| 164 |
+
else:
|
| 165 |
+
ctx.save_for_backward(cos, sin, cos_k, sin_k, seqlen_offsets)
|
| 166 |
+
ctx.seqlen_offsets = None
|
| 167 |
+
ctx.interleaved = interleaved
|
| 168 |
+
return qkv
|
| 169 |
+
|
| 170 |
+
@staticmethod
|
| 171 |
+
def backward(ctx, dqkv):
|
| 172 |
+
seqlen_offsets = ctx.seqlen_offsets
|
| 173 |
+
if seqlen_offsets is None:
|
| 174 |
+
cos, sin, cos_k, sin_k, seqlen_offsets = ctx.saved_tensors
|
| 175 |
+
else:
|
| 176 |
+
cos, sin, cos_k, sin_k = ctx.saved_tensors
|
| 177 |
+
if cos_k is None and sin_k is None and dqkv.is_contiguous():
|
| 178 |
+
# Call 1 kernel instead of 2 kernels
|
| 179 |
+
# We need dqkv to be contiguous so that when we reshape to combine (3, nheads)
|
| 180 |
+
# dimensions, we get the same tensor
|
| 181 |
+
dqk = rearrange(dqkv[:, :, :2], "b s t h d -> b s (t h) d")
|
| 182 |
+
apply_rotary(
|
| 183 |
+
dqk,
|
| 184 |
+
cos,
|
| 185 |
+
sin,
|
| 186 |
+
seqlen_offsets=seqlen_offsets,
|
| 187 |
+
interleaved=ctx.interleaved,
|
| 188 |
+
inplace=True,
|
| 189 |
+
conjugate=True,
|
| 190 |
+
)
|
| 191 |
+
else:
|
| 192 |
+
cos_k = cos if cos_k is None else cos_k
|
| 193 |
+
sin_k = sin if sin_k is None else sin_k
|
| 194 |
+
dq, dk = dqkv[:, :, 0], dqkv[:, :, 1]
|
| 195 |
+
apply_rotary(
|
| 196 |
+
dq, cos, sin, seqlen_offsets, interleaved=ctx.interleaved, inplace=True, conjugate=True
|
| 197 |
+
)
|
| 198 |
+
apply_rotary(
|
| 199 |
+
dk,
|
| 200 |
+
cos_k,
|
| 201 |
+
sin_k,
|
| 202 |
+
seqlen_offsets,
|
| 203 |
+
interleaved=ctx.interleaved,
|
| 204 |
+
inplace=True,
|
| 205 |
+
conjugate=True,
|
| 206 |
+
)
|
| 207 |
+
return dqkv, None, None, None, None, None, None
|
| 208 |
+
|
| 209 |
+
|
| 210 |
+
def apply_rotary_emb_qkv_(
|
| 211 |
+
qkv,
|
| 212 |
+
cos,
|
| 213 |
+
sin,
|
| 214 |
+
cos_k=None,
|
| 215 |
+
sin_k=None,
|
| 216 |
+
interleaved=False,
|
| 217 |
+
seqlen_offsets: Union[int, torch.Tensor] = 0,
|
| 218 |
+
):
|
| 219 |
+
"""
|
| 220 |
+
Arguments:
|
| 221 |
+
qkv: (batch_size, seqlen, 3, nheads, headdim)
|
| 222 |
+
cos, sin: (seqlen, rotary_dim / 2)
|
| 223 |
+
cos_k, sin_k: (seqlen, rotary_dim / 2), optional
|
| 224 |
+
interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead of
|
| 225 |
+
1st half and 2nd half (GPT-NeoX style).
|
| 226 |
+
seqlen_offsets: (batch_size,) or int. Each sequence in Q and K is shifted by this amount.
|
| 227 |
+
Most commonly used in inference when we have KV cache.
|
| 228 |
+
Return:
|
| 229 |
+
qkv: (batch_size, seqlen, 3, nheads, headdim)
|
| 230 |
+
rotary_dim must be <= headdim
|
| 231 |
+
Apply rotary embedding *inplace* to the first rotary_dim of Q and K.
|
| 232 |
+
"""
|
| 233 |
+
return ApplyRotaryEmbQKV_.apply(qkv, cos, sin, cos_k, sin_k, interleaved, seqlen_offsets)
|
| 234 |
+
|
| 235 |
+
|
| 236 |
+
class ApplyRotaryEmbKV_(torch.autograd.Function):
|
| 237 |
+
@staticmethod
|
| 238 |
+
def forward(ctx, kv, cos, sin, interleaved=False, seqlen_offsets: Union[int, torch.Tensor] = 0):
|
| 239 |
+
batch, seqlen, two, nheads, headdim = kv.shape
|
| 240 |
+
assert two == 2
|
| 241 |
+
k = kv[:, :, 0]
|
| 242 |
+
apply_rotary(
|
| 243 |
+
k, cos, sin, seqlen_offsets=seqlen_offsets, interleaved=interleaved, inplace=True
|
| 244 |
+
)
|
| 245 |
+
if isinstance(seqlen_offsets, int):
|
| 246 |
+
ctx.save_for_backward(cos, sin) # Can't save int with save_for_backward
|
| 247 |
+
ctx.seqlen_offsets = seqlen_offsets
|
| 248 |
+
else:
|
| 249 |
+
ctx.save_for_backward(cos, sin, seqlen_offsets)
|
| 250 |
+
ctx.seqlen_offsets = None
|
| 251 |
+
ctx.interleaved = interleaved
|
| 252 |
+
return kv
|
| 253 |
+
|
| 254 |
+
@staticmethod
|
| 255 |
+
def backward(ctx, dkv):
|
| 256 |
+
seqlen_offsets = ctx.seqlen_offsets
|
| 257 |
+
if seqlen_offsets is None:
|
| 258 |
+
cos, sin, seqlen_offsets = ctx.saved_tensors
|
| 259 |
+
else:
|
| 260 |
+
cos, sin = ctx.saved_tensors
|
| 261 |
+
apply_rotary(
|
| 262 |
+
dkv[:, :, 0],
|
| 263 |
+
cos,
|
| 264 |
+
sin,
|
| 265 |
+
seqlen_offsets=seqlen_offsets,
|
| 266 |
+
interleaved=ctx.interleaved,
|
| 267 |
+
inplace=True,
|
| 268 |
+
conjugate=True,
|
| 269 |
+
)
|
| 270 |
+
return dkv, None, None, None, None
|
| 271 |
+
|
| 272 |
+
|
| 273 |
+
apply_rotary_emb_kv_ = ApplyRotaryEmbKV_.apply
|
| 274 |
+
|
| 275 |
+
|
| 276 |
+
def apply_rotary_emb_kv_(
|
| 277 |
+
kv,
|
| 278 |
+
cos,
|
| 279 |
+
sin,
|
| 280 |
+
interleaved=False,
|
| 281 |
+
seqlen_offsets: Union[int, torch.Tensor] = 0,
|
| 282 |
+
):
|
| 283 |
+
"""
|
| 284 |
+
Arguments:
|
| 285 |
+
kv: (batch_size, seqlen, 2, nheads, headdim)
|
| 286 |
+
cos, sin: (seqlen, rotary_dim / 2)
|
| 287 |
+
interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead of
|
| 288 |
+
1st half and 2nd half (GPT-NeoX style).
|
| 289 |
+
seqlen_offsets: (batch_size,) or int. Each sequence in Q and K is shifted by this amount.
|
| 290 |
+
Most commonly used in inference when we have KV cache.
|
| 291 |
+
Return:
|
| 292 |
+
kv: (batch_size, seqlen, 2, nheads, headdim)
|
| 293 |
+
rotary_dim must be <= headdim
|
| 294 |
+
Apply rotary embedding *inplace* to the first rotary_dim of K.
|
| 295 |
+
"""
|
| 296 |
+
return ApplyRotaryEmbKV_.apply(kv, cos, sin, interleaved, seqlen_offsets)
|
| 297 |
+
|
| 298 |
+
|
| 299 |
+
class RotaryEmbedding(torch.nn.Module):
|
| 300 |
+
"""
|
| 301 |
+
The rotary position embeddings from RoFormer_ (Su et. al).
|
| 302 |
+
A crucial insight from the method is that the query and keys are
|
| 303 |
+
transformed by rotation matrices which depend on the relative positions.
|
| 304 |
+
|
| 305 |
+
Other implementations are available in the Rotary Transformer repo_ and in
|
| 306 |
+
GPT-NeoX_, GPT-NeoX was an inspiration
|
| 307 |
+
|
| 308 |
+
.. _RoFormer: https://arxiv.org/abs/2104.09864
|
| 309 |
+
.. _repo: https://github.com/ZhuiyiTechnology/roformer
|
| 310 |
+
.. _GPT-NeoX: https://github.com/EleutherAI/gpt-neox
|
| 311 |
+
|
| 312 |
+
If scale_base is not None, this implements XPos (Sun et al., https://arxiv.org/abs/2212.10554).
|
| 313 |
+
A recommended value for scale_base is 512: https://github.com/HazyResearch/flash-attention/issues/96
|
| 314 |
+
Reference: https://github.com/sunyt32/torchscale/blob/main/torchscale/component/xpos_relative_position.py
|
| 315 |
+
"""
|
| 316 |
+
|
| 317 |
+
def __init__(
|
| 318 |
+
self,
|
| 319 |
+
dim: int,
|
| 320 |
+
base=10000.0,
|
| 321 |
+
interleaved=False,
|
| 322 |
+
scale_base=None,
|
| 323 |
+
pos_idx_in_fp32=True,
|
| 324 |
+
device=None,
|
| 325 |
+
):
|
| 326 |
+
"""
|
| 327 |
+
interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead
|
| 328 |
+
of 1st half and 2nd half (GPT-NeoX style).
|
| 329 |
+
pos_idx_in_fp32: if True, the position indices [0.0, ..., seqlen - 1] are in fp32,
|
| 330 |
+
otherwise they might be in lower precision.
|
| 331 |
+
This option was added because previously (before 2023-07-02), when we construct
|
| 332 |
+
the position indices, we use the dtype of self.inv_freq. In most cases this would
|
| 333 |
+
be fp32, but if the model is trained in pure bf16 (not mixed precision), then
|
| 334 |
+
self.inv_freq would be bf16, and the position indices are also in bf16.
|
| 335 |
+
Because of the limited precision of bf16 (e.g. 1995.0 is rounded to 2000.0), the
|
| 336 |
+
embeddings for some positions will coincide.
|
| 337 |
+
To maintain compatibility with models previously trained in pure bf16,
|
| 338 |
+
we add this option.
|
| 339 |
+
"""
|
| 340 |
+
super().__init__()
|
| 341 |
+
self.dim = dim
|
| 342 |
+
self.base = float(base)
|
| 343 |
+
self.pos_idx_in_fp32 = pos_idx_in_fp32
|
| 344 |
+
# Generate and save the inverse frequency buffer (non trainable)
|
| 345 |
+
inv_freq = self._compute_inv_freq(device)
|
| 346 |
+
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
| 347 |
+
self.interleaved = interleaved
|
| 348 |
+
self.scale_base = scale_base
|
| 349 |
+
scale = (
|
| 350 |
+
(torch.arange(0, dim, 2, device=device, dtype=torch.float32) + 0.4 * dim) / (1.4 * dim)
|
| 351 |
+
if scale_base is not None
|
| 352 |
+
else None
|
| 353 |
+
)
|
| 354 |
+
self.register_buffer("scale", scale, persistent=False)
|
| 355 |
+
|
| 356 |
+
self._seq_len_cached = 0
|
| 357 |
+
self._cos_cached = None
|
| 358 |
+
self._sin_cached = None
|
| 359 |
+
self._cos_k_cached = None
|
| 360 |
+
self._sin_k_cached = None
|
| 361 |
+
|
| 362 |
+
def _compute_inv_freq(self, device=None):
|
| 363 |
+
return 1.0 / (
|
| 364 |
+
self.base
|
| 365 |
+
** (torch.arange(0, self.dim, 2, device=device, dtype=torch.float32) / self.dim)
|
| 366 |
+
)
|
| 367 |
+
|
| 368 |
+
def _update_cos_sin_cache(self, seqlen, device=None, dtype=None):
|
| 369 |
+
# Reset the tables if the sequence length has changed,
|
| 370 |
+
# if we're on a new device (possibly due to tracing for instance),
|
| 371 |
+
# or if we're switching from inference mode to training
|
| 372 |
+
if (
|
| 373 |
+
seqlen > self._seq_len_cached
|
| 374 |
+
or self._cos_cached is None
|
| 375 |
+
or self._cos_cached.device != device
|
| 376 |
+
or self._cos_cached.dtype != dtype
|
| 377 |
+
or (self.training and self._cos_cached.is_inference())
|
| 378 |
+
):
|
| 379 |
+
self._seq_len_cached = seqlen
|
| 380 |
+
# We want fp32 here, not self.inv_freq.dtype, since the model could be loaded in bf16
|
| 381 |
+
# And the output of arange can be quite large, so bf16 would lose a lot of precision.
|
| 382 |
+
# However, for compatibility reason, we add an option to use the dtype of self.inv_freq.
|
| 383 |
+
if self.pos_idx_in_fp32:
|
| 384 |
+
t = torch.arange(seqlen, device=device, dtype=torch.float32)
|
| 385 |
+
# We want fp32 here as well since inv_freq will be multiplied with t, and the output
|
| 386 |
+
# will be large. Having it in bf16 will lose a lot of precision and cause the
|
| 387 |
+
# cos & sin output to change significantly.
|
| 388 |
+
# We want to recompute self.inv_freq if it was not loaded in fp32
|
| 389 |
+
if self.inv_freq.dtype != torch.float32:
|
| 390 |
+
inv_freq = self._compute_inv_freq(device=device)
|
| 391 |
+
else:
|
| 392 |
+
inv_freq = self.inv_freq
|
| 393 |
+
else:
|
| 394 |
+
t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
|
| 395 |
+
inv_freq = self.inv_freq
|
| 396 |
+
# Don't do einsum, it converts fp32 to fp16 under AMP
|
| 397 |
+
# freqs = torch.einsum("i,j->ij", t, self.inv_freq)
|
| 398 |
+
freqs = torch.outer(t, inv_freq)
|
| 399 |
+
if self.scale is None:
|
| 400 |
+
self._cos_cached = torch.cos(freqs).to(dtype)
|
| 401 |
+
self._sin_cached = torch.sin(freqs).to(dtype)
|
| 402 |
+
else:
|
| 403 |
+
power = (
|
| 404 |
+
torch.arange(seqlen, dtype=self.scale.dtype, device=self.scale.device)
|
| 405 |
+
- seqlen // 2
|
| 406 |
+
) / self.scale_base
|
| 407 |
+
scale = self.scale.to(device=power.device) ** rearrange(power, "s -> s 1")
|
| 408 |
+
# We want the multiplication by scale to happen in fp32
|
| 409 |
+
self._cos_cached = (torch.cos(freqs) * scale).to(dtype)
|
| 410 |
+
self._sin_cached = (torch.sin(freqs) * scale).to(dtype)
|
| 411 |
+
self._cos_k_cached = (torch.cos(freqs) / scale).to(dtype)
|
| 412 |
+
self._sin_k_cached = (torch.sin(freqs) / scale).to(dtype)
|
| 413 |
+
|
| 414 |
+
    def forward(
        self,
        qkv: torch.Tensor,
        kv: Optional[torch.Tensor] = None,
        seqlen_offset: Union[int, torch.Tensor] = 0,
        max_seqlen: Optional[int] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """Apply rotary embedding *inplace* to qkv and / or kv.

        qkv: (batch, seqlen, 3, nheads, headdim) if kv is none,
            else it's just q of shape (batch, seqlen, nheads, headdim)
        kv: (batch, seqlen, 2, nheads, headdim)
        seqlen_offset: (batch_size,) or int. Each sequence in x is shifted by this amount.
            Most commonly used in inference when we have KV cache.
            If it's a tensor of shape (batch_size,), then to update the cos / sin cache, one
            should pass in max_seqlen, which will update the cos / sin cache up to that length.
        """
        seqlen = qkv.shape[1]
        # Make sure the cos/sin cache covers every position we will index.
        # When seqlen_offset is a per-batch tensor we cannot compute the max length here,
        # so the caller must supply max_seqlen explicitly.
        if max_seqlen is not None:
            self._update_cos_sin_cache(max_seqlen, device=qkv.device, dtype=qkv.dtype)
        elif isinstance(seqlen_offset, int):
            self._update_cos_sin_cache(seqlen + seqlen_offset, device=qkv.device, dtype=qkv.dtype)
        if kv is None:
            # Packed QKV path: rotate q and k inside the single (…, 3, …) tensor.
            if self.scale is None:
                return apply_rotary_emb_qkv_(
                    qkv,
                    self._cos_cached,
                    self._sin_cached,
                    interleaved=self.interleaved,
                    seqlen_offsets=seqlen_offset,
                )
            else:
                # xPos-style scaling: q and k use different (scaled) cos/sin tables.
                return apply_rotary_emb_qkv_(
                    qkv,
                    self._cos_cached,
                    self._sin_cached,
                    self._cos_k_cached,
                    self._sin_k_cached,
                    interleaved=self.interleaved,
                    seqlen_offsets=seqlen_offset,
                )
        else:
            # Separate q / kv path: `qkv` is actually just q here (see docstring).
            q = qkv
            q = apply_rotary_emb_func(
                q,
                self._cos_cached,
                self._sin_cached,
                interleaved=self.interleaved,
                inplace=True,
                seqlen_offsets=seqlen_offset,
            )
            if self.scale is None:
                kv = apply_rotary_emb_kv_(
                    kv,
                    self._cos_cached,
                    self._sin_cached,
                    interleaved=self.interleaved,
                    seqlen_offsets=seqlen_offset,
                )
            else:
                # With xPos scaling, keys use the inverse-scaled cache (_cos_k/_sin_k).
                kv = apply_rotary_emb_kv_(
                    kv,
                    self._cos_k_cached,
                    self._sin_k_cached,
                    interleaved=self.interleaved,
                    seqlen_offsets=seqlen_offset,
                )
            return q, kv
|
.venv/lib/python3.11/site-packages/xformers/_flash_attn/losses/__init__.py
ADDED
|
File without changes
|
.venv/lib/python3.11/site-packages/xformers/_flash_attn/losses/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (200 Bytes). View file
|
|
|
.venv/lib/python3.11/site-packages/xformers/_flash_attn/losses/__pycache__/cross_entropy.cpython-311.pyc
ADDED
|
Binary file (3.89 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/xformers/_flash_attn/losses/cross_entropy.py
ADDED
|
@@ -0,0 +1,85 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2024, Tri Dao.
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
|
| 6 |
+
from flash_attn.ops.triton.cross_entropy import cross_entropy_loss
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class CrossEntropyLoss(nn.Module):
    """Cross-entropy loss backed by the Triton fused kernel, with optional z-loss.

    Drop-in replacement for ``nn.CrossEntropyLoss`` (CUDA tensors only) that also
    supports logit scaling, lse^2 ("z-loss") regularization, in-place backward,
    and tensor-parallel vocab sharding via ``process_group``.
    """

    def __init__(
        self,
        ignore_index=-100,
        reduction="mean",
        label_smoothing=0.0,
        logit_scale=1.0,
        lse_square_scale=0.0,
        inplace_backward=False,
        process_group=None,
        return_z_loss=False,
    ):
        """
        Arguments:
            ignore_index: int. If labels == ignore_index, the loss is set to 0.0.
            label_smoothing: float
            lse_square_scale: float. If > 0, we add lse_square_scale * lse(logits) ^ 2 to the loss.
                This is also referred to as "z-loss".
            inplace_backward: bool. If True, we do the backward pass in-place by modifying the logits.
                This saves memory.
            process_group: if not None, we're doing Tensor Parallel: each process is responsible for
                one part of the vocab. The loss will be aggregated across processes.
            return_z_loss: bool. If True, we return the component of the loss contributed by
                the lse_square_scale value. This value is only for logging and does not support
                backprop.
        """
        super().__init__()
        if reduction not in ["mean", "none", "sum"]:
            raise NotImplementedError("Only support reduction = 'mean' or 'none' or 'sum'")
        self.ignore_index = ignore_index
        self.reduction = reduction
        self.label_smoothing = label_smoothing
        self.logit_scale = logit_scale
        self.lse_square_scale = lse_square_scale
        self.inplace_backward = inplace_backward
        self.process_group = process_group
        self.return_z_loss = return_z_loss

    def _reduce(self, values, target):
        """Apply self.reduction to per-sample values.

        For 'mean' we average over non-ignored targets only (matching
        nn.CrossEntropyLoss semantics); 'none' returns values unchanged.
        """
        if self.reduction == "mean":
            return values.sum() / (target != self.ignore_index).sum()
        if self.reduction == "sum":
            return values.sum()
        return values

    def forward(self, input, target, precomputed_lse=None):
        """
        Arguments:
            input: (batch, vocab_size)
            target: (batch,)
        Returns:
            losses: (batch,) if reduction is 'none', else (1,), dtype float
            z_loss: (batch,) if reduction is 'none', else (1,), dtype float (if self.return_z_loss)
        """
        assert input.is_cuda and target.is_cuda, "Only support CUDA tensors"
        loss, z_loss = cross_entropy_loss(
            input,
            target,
            precomputed_lse=precomputed_lse,
            label_smoothing=self.label_smoothing,
            logit_scale=self.logit_scale,
            lse_square_scale=self.lse_square_scale,
            ignore_index=self.ignore_index,
            inplace_backward=self.inplace_backward,
            process_group=self.process_group,
        )
        loss = self._reduce(loss, target)
        if not self.return_z_loss:
            return loss
        # z_loss is logging-only (no backprop support); reduce it the same way.
        return loss, self._reduce(z_loss, target)
|
.venv/lib/python3.11/site-packages/xformers/_flash_attn/models/__init__.py
ADDED
|
File without changes
|
.venv/lib/python3.11/site-packages/xformers/_flash_attn/models/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (200 Bytes). View file
|
|
|
.venv/lib/python3.11/site-packages/xformers/_flash_attn/models/__pycache__/baichuan.cpython-311.pyc
ADDED
|
Binary file (7.82 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/xformers/_flash_attn/models/__pycache__/bert.cpython-311.pyc
ADDED
|
Binary file (41.9 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/xformers/_flash_attn/models/__pycache__/bigcode.cpython-311.pyc
ADDED
|
Binary file (13.1 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/xformers/_flash_attn/models/__pycache__/btlm.cpython-311.pyc
ADDED
|
Binary file (7.7 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/xformers/_flash_attn/models/__pycache__/falcon.cpython-311.pyc
ADDED
|
Binary file (8.53 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/xformers/_flash_attn/models/__pycache__/gpt.cpython-311.pyc
ADDED
|
Binary file (54.6 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/xformers/_flash_attn/models/__pycache__/gpt_neox.cpython-311.pyc
ADDED
|
Binary file (7.87 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/xformers/_flash_attn/models/__pycache__/gptj.cpython-311.pyc
ADDED
|
Binary file (7.35 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/xformers/_flash_attn/models/__pycache__/llama.cpython-311.pyc
ADDED
|
Binary file (23.1 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/xformers/_flash_attn/models/__pycache__/opt.cpython-311.pyc
ADDED
|
Binary file (7.75 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/xformers/_flash_attn/models/__pycache__/vit.cpython-311.pyc
ADDED
|
Binary file (16.7 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/xformers/_flash_attn/models/baichuan.py
ADDED
|
@@ -0,0 +1,151 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2023, GGGGGGXY, Tri Dao.
|
| 2 |
+
|
| 3 |
+
import math
|
| 4 |
+
import json
|
| 5 |
+
import re
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
|
| 8 |
+
from collections import OrderedDict
|
| 9 |
+
|
| 10 |
+
import torch
|
| 11 |
+
import torch.nn.functional as F
|
| 12 |
+
|
| 13 |
+
from einops import rearrange
|
| 14 |
+
from transformers import GPT2Config, AutoConfig, PretrainedConfig
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def remap_state_dict_hf_baichuan(state_dict, config):
    """Remap a HuggingFace Baichuan state dict to the flash-attn GPT naming scheme.

    Renames embeddings, layer norms, MLP and attention projections, pads the
    word/output embeddings to ``pad_vocab_size_multiple``, and drops the cached
    ``rotary_emb.inv_freq`` buffers.

    Fix vs. the original: the regex patterns now escape the literal ``.``
    separators (``^model\\.`` instead of ``^model.``) so a key like
    ``modelXfoo`` can no longer be accidentally rewritten.
    """

    def key_mapping_layers(key):
        return re.sub(r"^model\.", "transformer.", key)

    state_dict = OrderedDict((key_mapping_layers(k), v) for k, v in state_dict.items())

    # Word embedding
    def key_mapping_emb(key):
        return re.sub(
            r"^transformer\.embed_tokens\.",
            "transformer.embeddings.word_embeddings.",
            key,
        )

    state_dict = OrderedDict((key_mapping_emb(k), v) for k, v in state_dict.items())
    word_embeddings = state_dict.pop("transformer.embeddings.word_embeddings.weight")
    # It's possible that vocab_size is padded to be a multiple of 8, for example.
    pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
    vocab_size = (
        math.ceil(word_embeddings.shape[0] / pad_vocab_size_multiple)
        * pad_vocab_size_multiple
    )
    state_dict["transformer.embeddings.word_embeddings.weight"] = F.pad(
        word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0])
    )
    if getattr(config, "tie_word_embeddings"):
        state_dict["lm_head.weight"] = state_dict[
            "transformer.embeddings.word_embeddings.weight"
        ]
    else:
        output_embeddings = state_dict.pop("lm_head.weight")
        # Need to recompute vocab_size since Baichuan shards the word embeddings and output
        # embeddings differently.
        vocab_size = (
            math.ceil(output_embeddings.shape[0] / pad_vocab_size_multiple)
            * pad_vocab_size_multiple
        )
        # It's possible that vocab_size is padded to be a multiple of 8, for example.
        state_dict["lm_head.weight"] = F.pad(
            output_embeddings, (0, 0, 0, vocab_size - output_embeddings.shape[0])
        )

    # LayerNorm
    def key_mapping_ln(key):
        key = re.sub(r"^transformer\.norm\.", r"transformer.ln_f.", key)
        key = re.sub(
            r"^transformer\.layers\.(\d+)\.input_layernorm\.",
            r"transformer.layers.\1.norm1.",
            key,
        )
        key = re.sub(
            r"^transformer\.layers\.(\d+)\.post_attention_layernorm\.",
            r"transformer.layers.\1.norm2.",
            key,
        )
        return key

    state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items())

    # MLP: merge gate_proj / up_proj into a single fc1 weight.
    for l in range(config.n_layer):
        w1 = state_dict.pop(f"transformer.layers.{l}.mlp.gate_proj.weight")
        w3 = state_dict.pop(f"transformer.layers.{l}.mlp.up_proj.weight")
        # Our ordering is different
        state_dict[f"transformer.layers.{l}.mlp.fc1.weight"] = torch.cat(
            [w3, w1], dim=0
        )

    def key_mapping_mlp(key):
        return re.sub(
            r"^transformer\.layers\.(\d+)\.mlp\.down_proj\.",
            r"transformer.layers.\1.mlp.fc2.",
            key,
        )

    state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items())

    # Attention
    def key_mapping_attn(key):
        key = re.sub(
            r"^transformer\.layers\.(\d+)\.self_attn\.W_pack\.",
            r"transformer.layers.\1.mixer.Wqkv.",
            key,
        )
        key = re.sub(
            r"^transformer\.layers\.(\d+)\.self_attn\.o_proj\.",
            r"transformer.layers.\1.mixer.out_proj.",
            key,
        )
        return key

    state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())
    for l in range(config.n_layer):
        # pop rotary_emb.inv_freq from state dict (recomputed at load time)
        state_dict.pop(f"transformer.layers.{l}.self_attn.rotary_emb.inv_freq", None)
    return state_dict
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
def baichuan_config_to_gpt2_config(baichuan_config: PretrainedConfig) -> GPT2Config:
    """Translate a HuggingFace Baichuan config into the flash-attn GPT2Config dialect.

    HACK: the config doesn't say whether it's rotary or alibi, so we infer from the
    hidden size (7B -> rotary, 13B -> alibi). Likewise it doesn't say whether it uses
    norm head, so we infer from the vocab size (v1, vocab size 64k, no norm head;
    v2, vocab size 128k, norm head).
    """
    is_rotary = baichuan_config.hidden_size < 5000
    gpt2_kwargs = dict(
        vocab_size=baichuan_config.vocab_size,
        n_positions=0,  # No absolute position embedding
        n_embd=baichuan_config.hidden_size,
        n_layer=baichuan_config.num_hidden_layers,
        n_head=baichuan_config.num_attention_heads,
        n_inner=baichuan_config.intermediate_size,
        activation_function="swiglu",  # Hardcode since HF calls it 'silu'
        # baichuan doesn't have dropout, idk if it's because they only release the inference code
        resid_pdrop=0.0,
        embd_pdrop=0.0,
        attn_pdrop=0.0,
        layer_norm_epsilon=baichuan_config.rms_norm_eps,
        initializer_range=baichuan_config.initializer_range,
        bos_token_id=baichuan_config.bos_token_id,
        eos_token_id=baichuan_config.eos_token_id,
        # These are new arguments not in the original GPT2Config
        pad_token_id=baichuan_config.pad_token_id,  # Idk if this does anything
        rms_norm=True,
        rotary_emb_fraction=1.0 if is_rotary else 0.0,
        rotary_emb_interleaved=False,
        use_alibi=not is_rotary,
        use_flash_attn=not is_rotary,  # Alibi code path requires flash_attn
        tie_word_embeddings=False,
        norm_head=baichuan_config.vocab_size > 70000,
        qkv_proj_bias=False,
        out_proj_bias=False,
        mlp_fc1_bias=False,
        mlp_fc2_bias=False,
    )
    return GPT2Config(**gpt2_kwargs)
|
.venv/lib/python3.11/site-packages/xformers/_flash_attn/models/bert.py
ADDED
|
@@ -0,0 +1,764 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2022, Tri Dao.
|
| 2 |
+
# This BERT implementation is based on our MLPerf 2.0 and MLPerf 2.1 BERT implementation.
|
| 3 |
+
# https://github.com/mlcommons/training_results_v2.0/blob/main/HazyResearch/benchmarks/bert/implementations/pytorch/modeling.py
|
| 4 |
+
# https://github.com/mlcommons/training_results_v2.1/blob/main/Azure-HazyResearch/benchmarks/bert/implementations/ND96amsr_A100_v4/modeling.py
|
| 5 |
+
|
| 6 |
+
# Inspired by https://github.com/huggingface/transformers/blob/main/src/transformers/models/bert/modeling_bert.py
|
| 7 |
+
|
| 8 |
+
import logging
|
| 9 |
+
import re
|
| 10 |
+
from collections import OrderedDict
|
| 11 |
+
from collections.abc import Sequence
|
| 12 |
+
from functools import partial
|
| 13 |
+
from typing import Any, Mapping
|
| 14 |
+
|
| 15 |
+
import torch
|
| 16 |
+
import torch.nn as nn
|
| 17 |
+
import torch.nn.functional as F
|
| 18 |
+
from einops import rearrange
|
| 19 |
+
from transformers import BertConfig, PretrainedConfig
|
| 20 |
+
from transformers.models.bert.modeling_bert import (
|
| 21 |
+
BaseModelOutputWithPoolingAndCrossAttentions,
|
| 22 |
+
BertForPreTrainingOutput,
|
| 23 |
+
)
|
| 24 |
+
|
| 25 |
+
from flash_attn.bert_padding import (
|
| 26 |
+
index_first_axis,
|
| 27 |
+
index_first_axis_residual,
|
| 28 |
+
pad_input,
|
| 29 |
+
unpad_input,
|
| 30 |
+
)
|
| 31 |
+
from flash_attn.modules.block import Block
|
| 32 |
+
from flash_attn.modules.embedding import BertEmbeddings
|
| 33 |
+
from flash_attn.modules.mha import MHA
|
| 34 |
+
from flash_attn.modules.mlp import FusedMLP, Mlp
|
| 35 |
+
from flash_attn.utils.pretrained import state_dict_from_pretrained
|
| 36 |
+
|
| 37 |
+
try:
|
| 38 |
+
from flash_attn.ops.fused_dense import FusedDense
|
| 39 |
+
except ImportError:
|
| 40 |
+
FusedDense = None
|
| 41 |
+
|
| 42 |
+
try:
|
| 43 |
+
from flash_attn.ops.triton.layer_norm import layer_norm_fn
|
| 44 |
+
except ImportError:
|
| 45 |
+
layer_norm_fn = None
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
try:
|
| 49 |
+
from flash_attn.losses.cross_entropy import CrossEntropyLoss
|
| 50 |
+
except ImportError:
|
| 51 |
+
CrossEntropyLoss = None
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
logger = logging.getLogger(__name__)
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
def create_mixer_cls(config, cross_attn=False, return_residual=False):
    """Return a partially-applied MHA class configured from a BERT config.

    Rotary-embedding kwargs are forwarded only when the config requests
    rotary position embeddings.
    """
    rotary_kwargs = {}
    if config.position_embedding_type == "rotary":
        rotary_kwargs = {
            "rotary_emb_dim": getattr(config, "rotary_emb_dim", config.hidden_size),
            "rotary_emb_base": getattr(config, "rotary_emb_base", 10000.0),
            "rotary_emb_scale_base": getattr(config, "rotary_emb_scale_base", None),
            "rotary_emb_interleaved": getattr(config, "rotary_emb_interleaved", False),
        }
    return partial(
        MHA,
        num_heads=config.num_attention_heads,
        cross_attn=cross_attn,
        dropout=config.attention_probs_dropout_prob,
        causal=False,  # BERT is bidirectional
        fused_bias_fc=getattr(config, "fused_bias_fc", False),
        use_flash_attn=getattr(config, "use_flash_attn", False),
        return_residual=return_residual,
        **rotary_kwargs,
    )
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
def create_mlp_cls(config, layer_idx=None, return_residual=False):
    """Return a partially-applied MLP class (plain Mlp or FusedMLP) from a BERT config."""
    inner_dim = config.intermediate_size
    if not getattr(config, "fused_mlp", False):
        # Plain path: map the HF activation name to torch's gelu approximation flag.
        gelu_approx = (
            "tanh"
            if config.hidden_act in ["gelu_new", "gelu_fast", "gelu_pytorch_tanh"]
            else "none"
        )
        return partial(
            Mlp,
            hidden_features=inner_dim,
            activation=partial(F.gelu, approximate=gelu_approx),
            return_residual=return_residual,
        )
    # Fused path: only the tanh-approximate gelu variants are supported.
    assert config.hidden_act in ["gelu_new", "gelu_fast", "gelu_pytorch_tanh"], (
        "fused_mlp only supports approximate gelu"
    )
    if FusedMLP is None:
        raise ImportError("fused_dense is not installed")
    checkpoint_lvl = getattr(config, "mlp_checkpoint_lvl", 0)
    # mlp_checkpoint_lvl could be a list, which contains the checkpoint_lvl for each layer
    if isinstance(checkpoint_lvl, Sequence):
        assert layer_idx is not None
        checkpoint_lvl = checkpoint_lvl[layer_idx]
    return partial(
        FusedMLP,
        hidden_features=inner_dim,
        checkpoint_lvl=checkpoint_lvl,
        return_residual=return_residual,
    )
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
def create_block(config, layer_idx=None):
    """Build one transformer Block (post-norm) for the BERT encoder stack."""
    is_final_layer = layer_idx == config.num_hidden_layers - 1
    cross_attn = getattr(config, "last_layer_subset", False) and is_final_layer
    # TD [2022-12-19]: For cross attention (last layer), we actually want to return the
    # residual x_kv, not residual x. But it's annoying to change the API (and it only affects
    # one layer) so we just choose not to return residual in this case.
    return_residual = not cross_attn
    return Block(
        config.hidden_size,
        create_mixer_cls(config, cross_attn, return_residual=return_residual),
        create_mlp_cls(config, layer_idx, return_residual=return_residual),
        norm_cls=partial(nn.LayerNorm, eps=config.layer_norm_eps),
        prenorm=False,  # BERT uses post-norm blocks
        resid_dropout1=config.hidden_dropout_prob,
        resid_dropout2=config.hidden_dropout_prob,
        fused_dropout_add_ln=getattr(config, "fused_dropout_add_ln", False),
        return_residual=return_residual,
    )
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
# https://github.com/huggingface/transformers/blob/7032e0203262ebb2ebf55da8d2e01f873973e835/src/transformers/models/bert/modeling_bert.py#L748
|
| 141 |
+
def _init_weights(module, initializer_range=0.02):
|
| 142 |
+
if isinstance(module, nn.Linear):
|
| 143 |
+
nn.init.normal_(module.weight, std=initializer_range)
|
| 144 |
+
if module.bias is not None:
|
| 145 |
+
nn.init.zeros_(module.bias)
|
| 146 |
+
elif isinstance(module, nn.Embedding):
|
| 147 |
+
nn.init.normal_(module.weight, std=initializer_range)
|
| 148 |
+
if module.padding_idx is not None:
|
| 149 |
+
nn.init.zeros_(module.weight[module.padding_idx])
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
class BertEncoder(nn.Module):
    """Stack of BERT transformer blocks with an optional unpadded (flash-attn) fast path."""

    def __init__(self, config: BertConfig):
        super().__init__()
        # When True and a key_padding_mask is given, forward() runs on unpadded
        # (varlen) sequences instead of the dense (batch, seqlen) layout.
        self.use_flash_attn = getattr(config, "use_flash_attn", False)
        self.layers = nn.ModuleList(
            [create_block(config, layer_idx=i) for i in range(config.num_hidden_layers)]
        )

    def forward(self, hidden_states, key_padding_mask=None, subset_mask=None):
        """If subset_mask is not None, we only want output for the subset of the sequence.
        This means that we only compute the last layer output for these tokens.
        subset_mask: (batch, seqlen), dtype=torch.bool
        """
        if key_padding_mask is None or not self.use_flash_attn:
            # Dense path: run every layer on the padded (batch, seqlen, hidden) tensor.
            mixer_kwargs = (
                {"key_padding_mask": key_padding_mask} if key_padding_mask is not None else None
            )
            for layer in self.layers:
                hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)
            if subset_mask is not None:
                # Boolean-index after the fact; output is (num_selected, hidden).
                hidden_states = hidden_states[subset_mask]
        else:
            # Unpadded path: remove padding tokens once up front, run all layers on the
            # packed (total_tokens, hidden) tensor, then re-pad (or subset) at the end.
            batch, seqlen = hidden_states.shape[:2]
            hidden_states, indices, cu_seqlens, max_seqlen_in_batch = unpad_input(
                hidden_states, key_padding_mask
            )
            mixer_kwargs = {"cu_seqlens": cu_seqlens, "max_seqlen": max_seqlen_in_batch}
            if subset_mask is None:
                for layer in self.layers:
                    hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)
                hidden_states = pad_input(hidden_states, indices, batch, seqlen)
            else:
                # Run all but the last layer on the full unpadded sequence; the last layer
                # attends from only the subset tokens (queries) to the full sequence (keys).
                for layer in self.layers[:-1]:
                    hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)
                if key_padding_mask is not None:
                    # Subset indices are relative to the packed (non-padding) token order.
                    subset_idx = torch.nonzero(
                        subset_mask[key_padding_mask], as_tuple=False
                    ).flatten()
                    subset_seqlens = (subset_mask & key_padding_mask).sum(dim=-1, dtype=torch.int32)
                    # Prepend a leading 0 to turn per-sequence lengths into cumulative offsets.
                    # (torch.torch.int32 is the same object as torch.int32.)
                    subset_cu_seqlens = F.pad(
                        torch.cumsum(subset_seqlens, dim=0, dtype=torch.torch.int32), (1, 0)
                    )
                else:
                    subset_idx = torch.nonzero(subset_mask, as_tuple=False).flatten()
                    subset_seqlens = subset_mask.sum(dim=-1, dtype=torch.int32)
                    subset_cu_seqlens = F.pad(
                        torch.cumsum(subset_seqlens, dim=0, dtype=torch.torch.int32), (1, 0)
                    )
                hidden_states_subset, hidden_states = index_first_axis_residual(
                    hidden_states, subset_idx
                )
                # It's ok to set max_seqlen_q to be much larger
                mixer_kwargs = {
                    "x_kv": hidden_states,
                    "cu_seqlens": subset_cu_seqlens,
                    "max_seqlen": max_seqlen_in_batch,
                    "cu_seqlens_k": cu_seqlens,
                    "max_seqlen_k": max_seqlen_in_batch,
                }
                hidden_states = self.layers[-1](hidden_states_subset, mixer_kwargs=mixer_kwargs)
        return hidden_states
|
| 213 |
+
|
| 214 |
+
|
| 215 |
+
class BertPooler(nn.Module):
    """Standard BERT pooler: dense + tanh over the first ([CLS]) token."""

    def __init__(self, config):
        super().__init__()
        fused_bias_fc = getattr(config, "fused_bias_fc", False)
        if fused_bias_fc and FusedDense is None:
            raise ImportError("fused_dense is not installed")
        linear_cls = FusedDense if fused_bias_fc else nn.Linear
        self.dense = linear_cls(config.hidden_size, config.hidden_size)
        self.activation = nn.Tanh()

    def forward(self, hidden_states, pool=True):
        # We "pool" the model by simply taking the hidden state corresponding to the
        # first token; with pool=False the transform is applied to every position.
        selected = hidden_states[:, 0] if pool else hidden_states
        return self.activation(self.dense(selected))
|
| 232 |
+
|
| 233 |
+
|
| 234 |
+
class BertPredictionHeadTransform(nn.Module):
    """Dense -> GELU -> LayerNorm transform that precedes the MLM decoder head."""

    def __init__(self, config):
        super().__init__()
        fused_bias_fc = getattr(config, "fused_bias_fc", False)
        if fused_bias_fc and FusedDense is None:
            raise ImportError("fused_dense is not installed")
        self.fused_dropout_add_ln = getattr(config, "fused_dropout_add_ln", False)
        if self.fused_dropout_add_ln and layer_norm_fn is None:
            raise ImportError("Triton is not installed")
        linear_cls = FusedDense if fused_bias_fc else nn.Linear
        self.dense = linear_cls(config.hidden_size, config.hidden_size)
        # Map the HF activation name to torch's gelu approximation flag.
        gelu_approx = (
            "tanh"
            if config.hidden_act in ["gelu_new", "gelu_fast", "gelu_pytorch_tanh"]
            else "none"
        )
        self.transform_act_fn = nn.GELU(approximate=gelu_approx)
        self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.transform_act_fn(self.dense(hidden_states))
        if self.fused_dropout_add_ln:
            # Triton fused layer-norm kernel path.
            return layer_norm_fn(
                hidden_states, self.layer_norm.weight, self.layer_norm.bias, eps=self.layer_norm.eps
            )
        return self.layer_norm(hidden_states)
|
| 263 |
+
|
| 264 |
+
|
| 265 |
+
class BertLMPredictionHead(nn.Module):
    """Masked-LM head: the prediction-head transform followed by a decoder
    projecting to vocab size.

    The decoder weight is typically tied to the input word embeddings by the
    caller (see BertForPreTraining.tie_weights); the decoder keeps its own
    output-only bias per token.
    """

    def __init__(self, config):
        super().__init__()
        fused_bias_fc = getattr(config, "fused_bias_fc", False)
        if fused_bias_fc and FusedDense is None:
            raise ImportError("fused_dense is not installed")
        linear_cls = FusedDense if fused_bias_fc else nn.Linear
        self.transform = BertPredictionHeadTransform(config)
        self.decoder = linear_cls(config.hidden_size, config.vocab_size, bias=True)

    def forward(self, hidden_states):
        return self.decoder(self.transform(hidden_states))
|
| 283 |
+
|
| 284 |
+
|
| 285 |
+
class BertPreTrainingHeads(nn.Module):
    """Bundle the MLM prediction head with the NSP (2-way) classifier."""

    def __init__(self, config):
        super().__init__()
        self.predictions = BertLMPredictionHead(config)
        # Next-sentence prediction: binary classification over the pooled output.
        self.seq_relationship = nn.Linear(config.hidden_size, 2)

    def forward(self, sequence_output, pooled_output):
        mlm_logits = self.predictions(sequence_output)
        nsp_logits = self.seq_relationship(pooled_output)
        return mlm_logits, nsp_logits
|
| 295 |
+
|
| 296 |
+
|
| 297 |
+
class BertPreTrainedModel(nn.Module):
    """An abstract class to handle weights initialization and a simple
    interface for downloading and loading pretrained models.
    """

    def __init__(self, config, *inputs, **kwargs):
        super().__init__()
        # All subclasses require a BertConfig; reject anything else up front.
        if not isinstance(config, BertConfig):
            raise ValueError(
                "Parameter config in `{}(config)` should be an instance of class `BertConfig`. "
                "To create a model from a Google pretrained model use "
                "`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format(
                    self.__class__.__name__, self.__class__.__name__
                )
            )
        self.config = config

    @classmethod
    def from_pretrained(cls, model_name, config, *inputs, **kwargs):
        """
        Instantiate a BertPreTrainedModel from a pre-trained model file or a pytorch state dict.
        Download and cache the pre-trained model file if needed.

        Params:
            model_name: name or path of the pretrained checkpoint, passed to
                `state_dict_from_pretrained`; the resulting state dict is
                remapped to this flash_attn layout via `remap_state_dict`.
            config: a `BertConfig` instance used to construct the model.
            *inputs, **kwargs: additional input for the specific Bert class
                (ex: num_labels for BertForSequenceClassification)

        Note: loading uses strict=False, so missing/unexpected keys are only
        logged (via logger.info), not raised.
        """
        # Instantiate model.
        model = cls(config, *inputs, **kwargs)
        load_return = model.load_state_dict(
            remap_state_dict(state_dict_from_pretrained(model_name), config), strict=False
        )
        logger.info(load_return)
        return model
|
| 338 |
+
|
| 339 |
+
|
| 340 |
+
class BertModel(BertPreTrainedModel):
    """BERT encoder stack (embeddings + encoder + optional pooler) in the
    flash_attn layout. The vocab size may be padded up to a multiple of
    ``pad_vocab_size_multiple``; note this mutates ``config.vocab_size``."""

    def __init__(self, config: BertConfig, add_pooling_layer=True):
        super().__init__(config)
        self.pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
        # Round vocab_size up to the next multiple (mutates the config in place).
        if config.vocab_size % self.pad_vocab_size_multiple != 0:
            config.vocab_size += self.pad_vocab_size_multiple - (
                config.vocab_size % self.pad_vocab_size_multiple
            )
        self.fused_dropout_add_ln = getattr(config, "fused_dropout_add_ln", False)
        if self.fused_dropout_add_ln and layer_norm_fn is None:
            raise ImportError("Triton is not installed")
        # Only GELU-family activations are supported by this implementation.
        assert config.hidden_act in ["gelu", "gelu_new", "gelu_fast", "gelu_pytorch_tanh"]

        self.embeddings = BertEmbeddings(
            config.hidden_size,
            config.vocab_size,
            config.max_position_embeddings,
            config.type_vocab_size,
            padding_idx=config.pad_token_id,
        )
        self.emb_drop = nn.Dropout(config.hidden_dropout_prob)
        self.emb_ln = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.encoder = BertEncoder(config)
        self.pooler = BertPooler(config) if add_pooling_layer else None

        self.apply(partial(_init_weights, initializer_range=config.initializer_range))

    def forward(
        self,
        input_ids,
        position_ids=None,
        token_type_ids=None,
        attention_mask=None,
        masked_tokens_mask=None,
    ):
        """If masked_tokens_mask is not None (i.e. last_layer_subset == True in BertForPreTraining),
        we only want the output for the masked tokens. This means that we only compute the last
        layer output for these tokens.
        masked_tokens_mask: (batch, seqlen), dtype=torch.bool
        """
        hidden_states = self.embeddings(
            input_ids, position_ids=position_ids, token_type_ids=token_type_ids
        )
        # TD [2022-12:18]: Don't need to force residual in fp32
        # BERT puts embedding LayerNorm before embedding dropout.
        if not self.fused_dropout_add_ln:
            hidden_states = self.emb_ln(hidden_states)
        else:
            hidden_states = layer_norm_fn(
                hidden_states, self.emb_ln.weight, self.emb_ln.bias, eps=self.emb_ln.eps
            )
        hidden_states = self.emb_drop(hidden_states)

        if masked_tokens_mask is not None:
            batch_size, seqlen = input_ids.shape[:2]
            # We also need the first column for the CLS token
            first_col_mask = torch.zeros(
                batch_size, seqlen, dtype=torch.bool, device=input_ids.device
            )
            first_col_mask[:, 0] = True
            subset_mask = masked_tokens_mask | first_col_mask
        else:
            subset_mask = None

        sequence_output = self.encoder(
            hidden_states, key_padding_mask=attention_mask, subset_mask=subset_mask
        )

        if masked_tokens_mask is None:
            pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
        else:
            # TD [2022-03-01]: the indexing here is very tricky.
            # With a subset_mask the encoder returns only the selected tokens
            # (masked tokens plus the first column); the boolean chains below
            # split that flat output back into CLS positions (for pooling)
            # and masked positions (for the MLM head).
            if attention_mask is not None:
                subset_idx = subset_mask[attention_mask]
                pool_input = sequence_output[first_col_mask[attention_mask][subset_idx]]
                sequence_output = sequence_output[masked_tokens_mask[attention_mask][subset_idx]]
            else:
                pool_input = sequence_output[first_col_mask[subset_mask]]
                sequence_output = sequence_output[masked_tokens_mask[subset_mask]]
            # pool=False: pool_input already holds only the CLS rows.
            pooled_output = self.pooler(pool_input, pool=False) if self.pooler is not None else None

        return BaseModelOutputWithPoolingAndCrossAttentions(
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
        )
|
| 425 |
+
|
| 426 |
+
|
| 427 |
+
class BertForPreTraining(BertPreTrainedModel):
    """BERT with the pretraining heads (masked LM + next-sentence prediction).

    Supports two speed optimizations: ``dense_seq_output`` feeds only the
    masked tokens to the classifier heads, and ``last_layer_subset``
    additionally computes only the last encoder layer for those tokens.
    """

    def __init__(self, config: BertConfig):
        super().__init__(config)
        # If dense_seq_output, we only need to pass the hidden states for the masked out tokens
        # (around 15%) to the classifier heads.
        self.dense_seq_output = getattr(config, "dense_seq_output", False)
        # If last_layer_subset, we only need the compute the last layer for a subset of tokens
        # (e.g., the tokens we need to compute the masked LM loss and the next-sentence prediction).
        self.last_layer_subset = getattr(config, "last_layer_subset", False)
        if self.last_layer_subset:
            assert self.dense_seq_output, "last_layer_subset requires dense_seq_output"
        use_xentropy = getattr(config, "use_xentropy", False)
        if use_xentropy and CrossEntropyLoss is None:
            raise ImportError("xentropy_cuda is not installed")
        loss_cls = (
            nn.CrossEntropyLoss
            if not use_xentropy
            else partial(CrossEntropyLoss, inplace_backward=True)
        )

        self.bert = BertModel(config)
        self.cls = BertPreTrainingHeads(config)
        # MLM labels use 0 as the "not masked" marker; NSP uses -1 to ignore.
        self.mlm_loss = loss_cls(ignore_index=0)
        self.nsp_loss = loss_cls(ignore_index=-1)

        # Initialize weights and apply final processing
        self.apply(partial(_init_weights, initializer_range=config.initializer_range))
        self.tie_weights()

    def tie_weights(self):
        # Share the decoder weight with the input word embeddings (weight tying).
        self.cls.predictions.decoder.weight = self.bert.embeddings.word_embeddings.weight

    def forward(
        self,
        input_ids,
        position_ids=None,
        token_type_ids=None,
        attention_mask=None,
        labels=None,
        next_sentence_label=None,
    ):
        """
        If labels are provided, they must be 0 for masked out tokens (as specified in the attention
        mask).
        Outputs:
            if `labels` and `next_sentence_label` are not `None`:
                Outputs the total_loss which is the sum of the masked language modeling loss and the next
                sentence classification loss.
            if `labels` or `next_sentence_label` is `None`:
                Outputs a tuple comprising
                - the masked language modeling logits of shape [batch_size, sequence_length, vocab_size], and
                - the next sentence classification logits of shape [batch_size, 2].

        """
        masked_tokens_mask = labels > 0 if (self.last_layer_subset and labels is not None) else None
        outputs = self.bert(
            input_ids,
            position_ids=position_ids,
            token_type_ids=token_type_ids,
            attention_mask=attention_mask.bool() if attention_mask is not None else None,
            masked_tokens_mask=masked_tokens_mask,
        )
        sequence_output, pooled_output = outputs.last_hidden_state, outputs.pooler_output
        if self.dense_seq_output and labels is not None:
            masked_token_idx = torch.nonzero(labels.flatten() > 0, as_tuple=False).flatten()
            # When last_layer_subset is set, the encoder already returned only
            # the masked tokens, so no gather is needed here.
            if not self.last_layer_subset:
                sequence_output = index_first_axis(
                    rearrange(sequence_output, "b s d -> (b s) d"), masked_token_idx
                )
        prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output)

        total_loss = None
        if labels is not None and next_sentence_label is not None:
            if (
                self.dense_seq_output and labels is not None
            ):  # prediction_scores are already flattened
                masked_lm_loss = self.mlm_loss(
                    prediction_scores, labels.flatten()[masked_token_idx]
                )
            else:
                masked_lm_loss = self.mlm_loss(
                    rearrange(prediction_scores, "... v -> (...) v"),
                    rearrange(labels, "... -> (...)"),
                )
            next_sentence_loss = self.nsp_loss(
                rearrange(seq_relationship_score, "... t -> (...) t"),
                rearrange(next_sentence_label, "... -> (...)"),
            )
            total_loss = masked_lm_loss.float() + next_sentence_loss.float()

        return BertForPreTrainingOutput(
            loss=total_loss,
            prediction_logits=prediction_scores,
            seq_relationship_logits=seq_relationship_score,
        )
|
| 522 |
+
|
| 523 |
+
|
| 524 |
+
def remap_state_dict(state_dict, config: PretrainedConfig):
    """
    Map the state_dict of a Huggingface BERT model to be flash_attn compatible.

    The remapping is applied as a sequence of key rewrites (LayerNorm names,
    layer indexing, MLP, attention) followed by tensor surgery: per-layer
    Q/K/V weights are concatenated into a single Wqkv (or Wq + Wkv for the
    last layer when ``last_layer_subset`` is set), and the word-embedding /
    decoder tensors are zero-padded up to ``config.vocab_size``. The rewrite
    order matters: later regexes assume the earlier renames already happened.
    """

    # LayerNorm: old-style gamma/beta parameter names -> weight/bias.
    def key_mapping_ln_gamma_beta(key):
        key = re.sub(r"LayerNorm.gamma$", "LayerNorm.weight", key)
        key = re.sub(r"LayerNorm.beta$", "LayerNorm.bias", key)
        return key

    state_dict = OrderedDict((key_mapping_ln_gamma_beta(k), v) for k, v in state_dict.items())

    # Layers: HF uses "layer", flash_attn uses "layers".
    def key_mapping_layers(key):
        return re.sub(r"^bert.encoder.layer.", "bert.encoder.layers.", key)

    state_dict = OrderedDict((key_mapping_layers(k), v) for k, v in state_dict.items())

    # LayerNorm
    def key_mapping_ln(key):
        key = re.sub(r"^bert.embeddings.LayerNorm.", "bert.emb_ln.", key)
        key = re.sub(
            r"^bert.encoder.layers.(\d+).attention.output.LayerNorm.(weight|bias)",
            r"bert.encoder.layers.\1.norm1.\2",
            key,
        )
        key = re.sub(
            r"^bert.encoder.layers.(\d+).output.LayerNorm.(weight|bias)",
            r"bert.encoder.layers.\1.norm2.\2",
            key,
        )
        key = re.sub(
            r"^cls.predictions.transform.LayerNorm.(weight|bias)",
            r"cls.predictions.transform.layer_norm.\1",
            key,
        )
        return key

    state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items())

    # MLP
    def key_mapping_mlp(key):
        key = re.sub(
            r"^bert.encoder.layers.(\d+).intermediate.dense.(weight|bias)",
            r"bert.encoder.layers.\1.mlp.fc1.\2",
            key,
        )
        key = re.sub(
            r"^bert.encoder.layers.(\d+).output.dense.(weight|bias)",
            r"bert.encoder.layers.\1.mlp.fc2.\2",
            key,
        )
        return key

    state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items())

    # Attention: fuse separate Q/K/V projections into one Wqkv tensor.
    last_layer_subset = getattr(config, "last_layer_subset", False)
    for d in range(config.num_hidden_layers):
        Wq = state_dict.pop(f"bert.encoder.layers.{d}.attention.self.query.weight")
        Wk = state_dict.pop(f"bert.encoder.layers.{d}.attention.self.key.weight")
        Wv = state_dict.pop(f"bert.encoder.layers.{d}.attention.self.value.weight")
        bq = state_dict.pop(f"bert.encoder.layers.{d}.attention.self.query.bias")
        bk = state_dict.pop(f"bert.encoder.layers.{d}.attention.self.key.bias")
        bv = state_dict.pop(f"bert.encoder.layers.{d}.attention.self.value.bias")
        if not (last_layer_subset and d == config.num_hidden_layers - 1):
            state_dict[f"bert.encoder.layers.{d}.mixer.Wqkv.weight"] = torch.cat(
                [Wq, Wk, Wv], dim=0
            )
            state_dict[f"bert.encoder.layers.{d}.mixer.Wqkv.bias"] = torch.cat([bq, bk, bv], dim=0)
        else:
            # The last layer with last_layer_subset uses a cross-attention-style
            # mixer with separate Wq and fused Wkv.
            state_dict[f"bert.encoder.layers.{d}.mixer.Wq.weight"] = Wq
            state_dict[f"bert.encoder.layers.{d}.mixer.Wkv.weight"] = torch.cat([Wk, Wv], dim=0)
            state_dict[f"bert.encoder.layers.{d}.mixer.Wq.bias"] = bq
            state_dict[f"bert.encoder.layers.{d}.mixer.Wkv.bias"] = torch.cat([bk, bv], dim=0)

    def key_mapping_attn(key):
        return re.sub(
            r"^bert.encoder.layers.(\d+).attention.output.dense.(weight|bias)",
            r"bert.encoder.layers.\1.mixer.out_proj.\2",
            key,
        )

    state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())

    def key_mapping_decoder_bias(key):
        return re.sub(r"^cls.predictions.bias", "cls.predictions.decoder.bias", key)

    state_dict = OrderedDict((key_mapping_decoder_bias(k), v) for k, v in state_dict.items())

    # Word embedding
    pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
    if pad_vocab_size_multiple > 1:
        # config.vocab_size is assumed already rounded up (see BertModel.__init__).
        word_embeddings = state_dict["bert.embeddings.word_embeddings.weight"]
        state_dict["bert.embeddings.word_embeddings.weight"] = F.pad(
            word_embeddings, (0, 0, 0, config.vocab_size - word_embeddings.shape[0])
        )
        decoder_weight = state_dict["cls.predictions.decoder.weight"]
        state_dict["cls.predictions.decoder.weight"] = F.pad(
            decoder_weight, (0, 0, 0, config.vocab_size - decoder_weight.shape[0])
        )
        # If the vocab was padded, we want to set the decoder bias for those padded indices to be
        # strongly negative (i.e. the decoder shouldn't predict those indices).
        # TD [2022-05-09]: I don't think it affects the MLPerf training.
        decoder_bias = state_dict["cls.predictions.decoder.bias"]
        state_dict["cls.predictions.decoder.bias"] = F.pad(
            decoder_bias, (0, config.vocab_size - decoder_bias.shape[0]), value=-100.0
        )

    return state_dict
|
| 635 |
+
|
| 636 |
+
|
| 637 |
+
def inv_remap_state_dict(state_dict, config: PretrainedConfig):
    """
    Map the state_dict of a flash_attn model to be Huggingface BERT compatible.

    This function is meant to be the inverse of remap_state_dict: it unpads
    the vocab tensors back to ``config.orig_vocab_size``, splits the fused
    Wqkv (or Wq/Wkv) tensors into separate Q/K/V projections, then reverses
    each key rewrite. The order of the key rewrites at the end mirrors the
    forward mapping in reverse.
    """
    # Word embedding
    pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
    if pad_vocab_size_multiple > 1:
        word_embeddings = state_dict["bert.embeddings.word_embeddings.weight"]
        decoder_weight = state_dict["cls.predictions.decoder.weight"]
        decoder_bias = state_dict["cls.predictions.decoder.bias"]
        # unpad embeddings
        state_dict["bert.embeddings.word_embeddings.weight"] = word_embeddings[
            : config.orig_vocab_size, :
        ]
        state_dict["cls.predictions.decoder.weight"] = decoder_weight[: config.orig_vocab_size, :]
        state_dict["cls.predictions.decoder.bias"] = decoder_bias[: config.orig_vocab_size]

    for d in range(config.num_hidden_layers):
        last_layer_subset = getattr(config, "last_layer_subset", False)
        if not last_layer_subset or d != (config.num_hidden_layers - 1):
            # Fused Wqkv: the first axis is [Q; K; V] in equal thirds.
            Wqkv_weights = state_dict.pop(f"bert.encoder.layers.{d}.mixer.Wqkv.weight")
            Wqkv_biases = state_dict.pop(f"bert.encoder.layers.{d}.mixer.Wqkv.bias")
            state_dict[f"bert.encoder.layers.{d}.attention.self.query.weight"] = Wqkv_weights[
                : Wqkv_weights.shape[0] // 3, :
            ]
            state_dict[f"bert.encoder.layers.{d}.attention.self.key.weight"] = Wqkv_weights[
                Wqkv_weights.shape[0] // 3 : 2 * Wqkv_weights.shape[0] // 3, :
            ]
            state_dict[f"bert.encoder.layers.{d}.attention.self.value.weight"] = Wqkv_weights[
                2 * Wqkv_weights.shape[0] // 3 :, :
            ]
            state_dict[f"bert.encoder.layers.{d}.attention.self.query.bias"] = Wqkv_biases[
                : Wqkv_biases.shape[0] // 3
            ]
            state_dict[f"bert.encoder.layers.{d}.attention.self.key.bias"] = Wqkv_biases[
                Wqkv_biases.shape[0] // 3 : 2 * Wqkv_biases.shape[0] // 3
            ]
            state_dict[f"bert.encoder.layers.{d}.attention.self.value.bias"] = Wqkv_biases[
                2 * Wqkv_biases.shape[0] // 3 :
            ]
        else:
            # Last layer with last_layer_subset: separate Wq plus fused [K; V].
            Wq_weight = state_dict.pop(f"bert.encoder.layers.{d}.mixer.Wq.weight")
            Wkv_weights = state_dict.pop(f"bert.encoder.layers.{d}.mixer.Wkv.weight")
            Wq_bias = state_dict.pop(f"bert.encoder.layers.{d}.mixer.Wq.bias")
            Wkv_biases = state_dict.pop(f"bert.encoder.layers.{d}.mixer.Wkv.bias")
            state_dict[f"bert.encoder.layers.{d}.attention.self.query.weight"] = Wq_weight
            state_dict[f"bert.encoder.layers.{d}.attention.self.key.weight"] = Wkv_weights[
                : Wkv_weights.shape[0] // 2, :
            ]
            state_dict[f"bert.encoder.layers.{d}.attention.self.value.weight"] = Wkv_weights[
                Wkv_weights.shape[0] // 2 :, :
            ]
            state_dict[f"bert.encoder.layers.{d}.attention.self.query.bias"] = Wq_bias
            state_dict[f"bert.encoder.layers.{d}.attention.self.key.bias"] = Wkv_biases[
                : Wkv_biases.shape[0] // 2
            ]
            state_dict[f"bert.encoder.layers.{d}.attention.self.value.bias"] = Wkv_biases[
                Wkv_biases.shape[0] // 2 :
            ]

    def inv_key_mapping_ln(key):
        key = re.sub(r"bert.emb_ln.", "bert.embeddings.LayerNorm.", key)
        key = re.sub(
            r"bert.encoder.layers.(\d+).norm1.(weight|bias)",
            r"bert.encoder.layers.\1.attention.output.LayerNorm.\2",
            key,
        )
        key = re.sub(
            r"bert.encoder.layers.(\d+).norm2.(weight|bias)",
            r"bert.encoder.layers.\1.output.LayerNorm.\2",
            key,
        )
        key = re.sub(
            r"cls.predictions.transform.layer_norm.(weight|bias)",
            r"cls.predictions.transform.LayerNorm.\1",
            key,
        )
        return key

    def inv_key_mapping_ln_gamma_beta(key):
        key = re.sub(r"LayerNorm.weight$", "LayerNorm.gamma", key)
        key = re.sub(r"LayerNorm.bias$", "LayerNorm.beta", key)
        return key

    def inv_key_mapping_layers(key):
        return re.sub(r"bert.encoder.layers.", "bert.encoder.layer.", key)

    def inv_key_mapping_mlp(key):
        key = re.sub(
            r"bert.encoder.layer.(\d+).mlp.fc1.(weight|bias)",
            r"bert.encoder.layer.\1.intermediate.dense.\2",
            key,
        )
        key = re.sub(
            r"bert.encoder.layer.(\d+).mlp.fc2.(weight|bias)",
            r"bert.encoder.layer.\1.output.dense.\2",
            key,
        )
        return key

    def inv_key_mapping_attn(key):
        return re.sub(
            r"bert.encoder.layer.(\d+).mixer.out_proj.(weight|bias)",
            r"bert.encoder.layer.\1.attention.output.dense.\2",
            key,
        )

    def inv_key_mapping_decoder_bias(key):
        return re.sub(r"cls.predictions.decoder.bias", "cls.predictions.bias", key)

    # Apply the inverse rewrites; each pass rebuilds the OrderedDict so the
    # next regex sees the already-renamed keys.
    state_dict = OrderedDict((inv_key_mapping_ln(key), value) for key, value in state_dict.items())
    state_dict = OrderedDict(
        (inv_key_mapping_ln_gamma_beta(key), value) for key, value in state_dict.items()
    )
    state_dict = OrderedDict(
        (inv_key_mapping_layers(key), value) for key, value in state_dict.items()
    )
    state_dict = OrderedDict((inv_key_mapping_mlp(key), value) for key, value in state_dict.items())
    state_dict = OrderedDict(
        (inv_key_mapping_attn(key), value) for key, value in state_dict.items()
    )
    state_dict = OrderedDict(
        (inv_key_mapping_decoder_bias(key), value) for key, value in state_dict.items()
    )

    return state_dict
|
.venv/lib/python3.11/site-packages/xformers/_flash_attn/models/bigcode.py
ADDED
|
@@ -0,0 +1,233 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
import re
|
| 3 |
+
from collections import OrderedDict
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
import torch.nn.functional as F
|
| 7 |
+
from transformers import GPT2Config, GPTBigCodeConfig, PretrainedConfig
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def remap_state_dict_hf_bigcode(state_dict, config: PretrainedConfig):
    """
    Map the state_dict of a Huggingface BigCode model to be flash_attn compatible.

    Renames keys (position embeddings, LayerNorms, MLP, attention projections)
    and converts the multi-query c_attn tensor — where K and V each have a
    single head — into a full Wqkv tensor by tiling K and V across all heads.
    Word embeddings are zero-padded up to a multiple of
    ``pad_vocab_size_multiple`` and shared with ``lm_head.weight``.
    """

    # Word embedding and position embedding
    def key_mapping_pos_emb(key):
        return re.sub(r"^transformer.wpe.", "transformer.embeddings.position_embeddings.", key)

    state_dict = OrderedDict((key_mapping_pos_emb(k), v) for k, v in state_dict.items())
    word_embeddings = state_dict.pop("transformer.wte.weight")
    # It's possible that vocab_size is padded to be a multiple of 8, for example.
    pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
    vocab_size = math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple
    state_dict["transformer.embeddings.word_embeddings.weight"] = F.pad(
        word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0])
    )
    # Tied weights: lm_head shares the (padded) word-embedding tensor.
    state_dict["lm_head.weight"] = state_dict["transformer.embeddings.word_embeddings.weight"]

    # LayerNorm
    def key_mapping_ln(key):
        # NOTE(review): the ln_f rewrite maps the key to itself — looks like a
        # deliberate no-op kept for symmetry with the other model remappers.
        key = re.sub(r"^transformer.ln_f.(weight|bias)", r"transformer.ln_f.\1", key)
        key = re.sub(
            r"^transformer.h.(\d+).ln_(1|2).(weight|bias)",
            r"transformer.layers.\1.norm\2.\3",
            key,
        )
        return key

    state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items())

    def key_mapping_mlp(key):
        key = re.sub(
            r"^transformer.h.(\d+).mlp.c_fc.weight",
            r"transformer.layers.\1.mlp.fc1.weight",
            key,
        )
        key = re.sub(
            r"^transformer.h.(\d+).mlp.c_proj.weight",
            r"transformer.layers.\1.mlp.fc2.weight",
            key,
        )
        key = re.sub(
            r"^transformer.h.(\d+).mlp.c_fc.bias",
            r"transformer.layers.\1.mlp.fc1.bias",
            key,
        )
        key = re.sub(
            r"^transformer.h.(\d+).mlp.c_proj.bias",
            r"transformer.layers.\1.mlp.fc2.bias",
            key,
        )
        return key

    state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items())

    # TODO: add support for multi-head attention
    assert config.multi_query, "Only multi-query attention is supported"

    # Attention
    for d in range(config.num_hidden_layers):
        embed_dim = config.n_embd
        head_dim = embed_dim // config.n_head

        c_attn_weight = state_dict.pop(f"transformer.h.{d}.attn.c_attn.weight")
        # with multi-query attention, the weights have shape (embed_dim, embed_dim + head_dim + head_dim)
        # see https://github.com/huggingface/transformers/blob/95b374952dc27d8511541d6f5a4e22c9ec11fb24/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py#L112
        # see also https://github.com/ggerganov/ggml/blob/dd1d575956e54c5bdc07632f25506b3b1884dbd2/examples/starcoder/convert-hf-to-ggml.py#L183
        # ((n_head + 2) * head_dim, embed_dim) -> (3 * n_heads * head_dim, hidden_dim)
        q, k, v = torch.split(c_attn_weight, [embed_dim, head_dim, head_dim], dim=0)
        # duplicate k, v along the first axis (head_dim, hidden_dim) -> (n_heads * head_dim, hidden_dim)
        k = torch.tile(k, (config.n_head, 1))
        v = torch.tile(v, (config.n_head, 1))
        state_dict[f"transformer.layers.{d}.mixer.Wqkv.weight"] = torch.cat((q, k, v), dim=0)

        # same deal with the bias
        c_attn_bias = state_dict.pop(f"transformer.h.{d}.attn.c_attn.bias")
        # ((n_head + 2) * head_dim, embed_dim) -> (3 * n_heads * head_dim, hidden_dim)
        q, k, v = torch.split(c_attn_bias, [embed_dim, head_dim, head_dim], dim=0)
        # duplicate k, v along the first axis (head_dim, hidden_dim) -> (n_heads * head_dim, hidden_dim)
        k = torch.tile(k, (config.n_head,))
        v = torch.tile(v, (config.n_head,))
        state_dict[f"transformer.layers.{d}.mixer.Wqkv.bias"] = torch.cat((q, k, v), dim=0)

    def key_mapping_attn(key):
        key = re.sub(
            r"^transformer.h.(\d+).attn.c_proj.weight",
            r"transformer.layers.\1.mixer.out_proj.weight",
            key,
        )
        key = re.sub(
            r"^transformer.h.(\d+).attn.c_proj.bias",
            r"transformer.layers.\1.mixer.out_proj.bias",
            key,
        )
        return key

    state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())

    return state_dict
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
def inv_remap_state_dict_hf_bigcode(state_dict, config: PretrainedConfig):
|
| 113 |
+
"""
|
| 114 |
+
Map the state_dict of a flash_attn model to be Huggingface BigCode compatible.
|
| 115 |
+
|
| 116 |
+
This function is meant to be the inverse of remap_state_dict_hf_bigcode.
|
| 117 |
+
"""
|
| 118 |
+
|
| 119 |
+
# Word embedding and position embeddings
|
| 120 |
+
def inv_key_mapping_pos_emb(key):
|
| 121 |
+
return re.sub(r"^transformer.embeddings.position_embeddings.", "transformer.wpe.", key)
|
| 122 |
+
|
| 123 |
+
state_dict = OrderedDict((inv_key_mapping_pos_emb(k), v) for k, v in state_dict.items())
|
| 124 |
+
word_embeddings = state_dict.pop("transformer.embeddings.word_embeddings.weight")
|
| 125 |
+
|
| 126 |
+
word_embeddings = word_embeddings[:, : config.vocab_size]
|
| 127 |
+
state_dict["transformer.wte.weight"] = word_embeddings
|
| 128 |
+
state_dict["lm_head.weight"] = word_embeddings
|
| 129 |
+
|
| 130 |
+
# LayerNorm
|
| 131 |
+
def inv_key_mapping_ln(key):
|
| 132 |
+
key = re.sub(r"^transformer.ln_f.(weight|bias)", r"transformer.ln_f.\1", key)
|
| 133 |
+
key = re.sub(
|
| 134 |
+
r"^transformer.layers.(\d+).norm(1|2).(weight|bias)",
|
| 135 |
+
r"transformer.h.\1.ln_\2.\3",
|
| 136 |
+
key,
|
| 137 |
+
)
|
| 138 |
+
return key
|
| 139 |
+
|
| 140 |
+
state_dict = OrderedDict((inv_key_mapping_ln(k), v) for k, v in state_dict.items())
|
| 141 |
+
|
| 142 |
+
# MLPs
|
| 143 |
+
def inv_key_mapping_mlp(key):
|
| 144 |
+
key = re.sub(
|
| 145 |
+
r"^transformer.layers.(\d+).mlp.fc1.weight",
|
| 146 |
+
r"transformer.h.\1.mlp.c_fc.weight",
|
| 147 |
+
key,
|
| 148 |
+
)
|
| 149 |
+
key = re.sub(
|
| 150 |
+
r"^transformer.layers.(\d+).mlp.fc2.weight",
|
| 151 |
+
r"transformer.h.\1.mlp.c_proj.weight",
|
| 152 |
+
key,
|
| 153 |
+
)
|
| 154 |
+
key = re.sub(
|
| 155 |
+
r"^transformer.layers.(\d+).mlp.fc1.bias",
|
| 156 |
+
r"transformer.h.\1.mlp.c_fc.bias",
|
| 157 |
+
key,
|
| 158 |
+
)
|
| 159 |
+
key = re.sub(
|
| 160 |
+
r"^transformer.layers.(\d+).mlp.fc2.bias",
|
| 161 |
+
r"transformer.h.\1.mlp.c_proj.bias",
|
| 162 |
+
key,
|
| 163 |
+
)
|
| 164 |
+
return key
|
| 165 |
+
|
| 166 |
+
state_dict = OrderedDict((inv_key_mapping_mlp(k), v) for k, v in state_dict.items())
|
| 167 |
+
|
| 168 |
+
# Attention
|
| 169 |
+
for d in range(config.num_hidden_layers):
|
| 170 |
+
embed_dim = config.n_embd
|
| 171 |
+
head_dim = embed_dim // config.n_head
|
| 172 |
+
|
| 173 |
+
Wqkv_weight = state_dict.pop(f"transformer.layers.{d}.mixer.Wqkv.weight")
|
| 174 |
+
q, k, v = torch.split(
|
| 175 |
+
Wqkv_weight, [embed_dim, head_dim * config.n_head, head_dim * config.n_head], dim=0
|
| 176 |
+
)
|
| 177 |
+
c_attn_weight = torch.cat((q, k[:head_dim], v[:head_dim]), dim=0)
|
| 178 |
+
state_dict[f"transformer.h.{d}.attn.c_attn.weight"] = c_attn_weight
|
| 179 |
+
|
| 180 |
+
# Same deal with the bias
|
| 181 |
+
Wqkv_bias = state_dict.pop(f"transformer.layers.{d}.mixer.Wqkv.bias")
|
| 182 |
+
q, k, v = torch.split(
|
| 183 |
+
Wqkv_bias, [embed_dim, head_dim * config.n_head, head_dim * config.n_head], dim=0
|
| 184 |
+
)
|
| 185 |
+
c_attn_bias = torch.cat((q, k[:head_dim], v[:head_dim]), dim=0)
|
| 186 |
+
state_dict[f"transformer.h.{d}.attn.c_attn.bias"] = c_attn_bias
|
| 187 |
+
|
| 188 |
+
def inv_key_mapping_attn(key):
|
| 189 |
+
key = re.sub(
|
| 190 |
+
r"^transformer.layers.(\d+).mixer.out_proj.weight",
|
| 191 |
+
r"transformer.h.\1.attn.c_proj.weight",
|
| 192 |
+
key,
|
| 193 |
+
)
|
| 194 |
+
key = re.sub(
|
| 195 |
+
r"^transformer.layers.(\d+).mixer.out_proj.bias",
|
| 196 |
+
r"transformer.h.\1.attn.c_proj.bias",
|
| 197 |
+
key,
|
| 198 |
+
)
|
| 199 |
+
return key
|
| 200 |
+
|
| 201 |
+
state_dict = OrderedDict((inv_key_mapping_attn(k), v) for k, v in state_dict.items())
|
| 202 |
+
|
| 203 |
+
return state_dict
|
| 204 |
+
|
| 205 |
+
|
| 206 |
+
def bigcode_config_to_gpt2_config(bigcode_config: GPTBigCodeConfig) -> GPT2Config:
|
| 207 |
+
return GPT2Config(
|
| 208 |
+
activation_function=bigcode_config.activation_function,
|
| 209 |
+
attn_pdrop=bigcode_config.attn_pdrop,
|
| 210 |
+
bos_token_id=bigcode_config.bos_token_id,
|
| 211 |
+
embd_pdrop=bigcode_config.embd_pdrop,
|
| 212 |
+
eos_token_id=bigcode_config.eos_token_id,
|
| 213 |
+
initializer_range=bigcode_config.initializer_range,
|
| 214 |
+
layer_norm_epsilon=bigcode_config.layer_norm_epsilon,
|
| 215 |
+
max_batch_size=bigcode_config.max_batch_size,
|
| 216 |
+
max_sequence_length=bigcode_config.max_sequence_length,
|
| 217 |
+
model_type=bigcode_config.model_type,
|
| 218 |
+
multi_query=bigcode_config.multi_query,
|
| 219 |
+
n_embd=bigcode_config.n_embd,
|
| 220 |
+
n_head=bigcode_config.n_head,
|
| 221 |
+
n_inner=bigcode_config.n_inner,
|
| 222 |
+
n_layer=bigcode_config.n_layer,
|
| 223 |
+
n_positions=bigcode_config.n_positions,
|
| 224 |
+
resid_pdrop=bigcode_config.resid_pdrop,
|
| 225 |
+
scale_attn_weights=bigcode_config.scale_attn_weights,
|
| 226 |
+
summary_activation=bigcode_config.summary_activation,
|
| 227 |
+
summary_first_dropout=bigcode_config.summary_first_dropout,
|
| 228 |
+
summary_proj_to_labels=bigcode_config.summary_proj_to_labels,
|
| 229 |
+
summary_type=bigcode_config.summary_type,
|
| 230 |
+
summary_use_proj=bigcode_config.summary_use_proj,
|
| 231 |
+
use_cache=bigcode_config.use_cache,
|
| 232 |
+
vocab_size=bigcode_config.vocab_size,
|
| 233 |
+
)
|
.venv/lib/python3.11/site-packages/xformers/_flash_attn/models/btlm.py
ADDED
|
@@ -0,0 +1,102 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2023, Tri Dao.
|
| 2 |
+
|
| 3 |
+
import math
|
| 4 |
+
import json
|
| 5 |
+
import re
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
|
| 8 |
+
from collections import OrderedDict
|
| 9 |
+
|
| 10 |
+
import torch
|
| 11 |
+
import torch.nn.functional as F
|
| 12 |
+
|
| 13 |
+
from einops import rearrange
|
| 14 |
+
from transformers import GPT2Config, AutoConfig, PretrainedConfig
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def remap_state_dict_hf_btlm(state_dict, config):
|
| 18 |
+
# Word embedding and position embedding
|
| 19 |
+
def key_mapping_pos_emb(key):
|
| 20 |
+
return re.sub(r"^transformer.wpe.", "transformer.embeddings.position_embeddings.", key)
|
| 21 |
+
|
| 22 |
+
if "transformer.wpe.weight" in state_dict:
|
| 23 |
+
state_dict = OrderedDict((key_mapping_pos_emb(k), v) for k, v in state_dict.items())
|
| 24 |
+
word_embeddings = state_dict.pop("transformer.wte.weight")
|
| 25 |
+
# It's possible that vocab_size is padded to be a multiple of 8, for example.
|
| 26 |
+
pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
|
| 27 |
+
vocab_size = math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple
|
| 28 |
+
state_dict["transformer.embeddings.word_embeddings.weight"] = F.pad(
|
| 29 |
+
word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0])
|
| 30 |
+
)
|
| 31 |
+
state_dict["lm_head.weight"] = state_dict["transformer.embeddings.word_embeddings.weight"]
|
| 32 |
+
|
| 33 |
+
# LayerNorm
|
| 34 |
+
def key_mapping_ln(key):
|
| 35 |
+
key = re.sub(r"^transformer.ln_f.(weight|bias)", r"transformer.ln_f.\1", key)
|
| 36 |
+
key = re.sub(r"^transformer.h.(\d+).ln_(1|2).(weight|bias)", r"transformer.layers.\1.norm\2.\3", key)
|
| 37 |
+
return key
|
| 38 |
+
|
| 39 |
+
state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items())
|
| 40 |
+
|
| 41 |
+
# MLP
|
| 42 |
+
for d in range(config.num_hidden_layers):
|
| 43 |
+
W1 = state_dict.pop(f"transformer.h.{d}.mlp.c_fc.weight")
|
| 44 |
+
W3 = state_dict.pop(f"transformer.h.{d}.mlp.c_fc2.weight")
|
| 45 |
+
state_dict[f"transformer.layers.{d}.mlp.fc1.weight"] = torch.cat([W1.t(), W3.t()], dim=0)
|
| 46 |
+
b1 = state_dict.pop(f"transformer.h.{d}.mlp.c_fc.bias")
|
| 47 |
+
b3 = state_dict.pop(f"transformer.h.{d}.mlp.c_fc2.bias")
|
| 48 |
+
state_dict[f"transformer.layers.{d}.mlp.fc1.bias"] = torch.cat([b1, b3], dim=0)
|
| 49 |
+
W2 = state_dict.pop(f"transformer.h.{d}.mlp.c_proj.weight")
|
| 50 |
+
state_dict[f"transformer.layers.{d}.mlp.fc2.weight"] = W2.t()
|
| 51 |
+
|
| 52 |
+
def key_mapping_mlp(key):
|
| 53 |
+
key = re.sub(r"^transformer.h.(\d+).mlp.c_proj.bias", r"transformer.layers.\1.mlp.fc2.bias", key)
|
| 54 |
+
return key
|
| 55 |
+
|
| 56 |
+
state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items())
|
| 57 |
+
|
| 58 |
+
# Attention
|
| 59 |
+
for d in range(config.num_hidden_layers):
|
| 60 |
+
Wqkv = state_dict.pop(f"transformer.h.{d}.attn.c_attn.weight")
|
| 61 |
+
state_dict[f"transformer.layers.{d}.mixer.Wqkv.weight"] = Wqkv.t()
|
| 62 |
+
Wout = state_dict.pop(f"transformer.h.{d}.attn.c_proj.weight")
|
| 63 |
+
state_dict[f"transformer.layers.{d}.mixer.out_proj.weight"] = Wout.t()
|
| 64 |
+
state_dict.pop(f"transformer.relative_pe.slopes") # We don't store the Alibi slopes
|
| 65 |
+
|
| 66 |
+
def key_mapping_attn(key):
|
| 67 |
+
key = re.sub(r"^transformer.h.(\d+).attn.c_attn.bias", r"transformer.layers.\1.mixer.Wqkv.bias", key)
|
| 68 |
+
key = re.sub(
|
| 69 |
+
r"^transformer.h.(\d+).attn.c_proj.bias", r"transformer.layers.\1.mixer.out_proj.bias", key
|
| 70 |
+
)
|
| 71 |
+
return key
|
| 72 |
+
|
| 73 |
+
state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())
|
| 74 |
+
|
| 75 |
+
return state_dict
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
def btlm_config_to_gpt2_config(btlm_config: PretrainedConfig) -> GPT2Config:
|
| 79 |
+
return GPT2Config(
|
| 80 |
+
vocab_size=btlm_config.vocab_size,
|
| 81 |
+
n_positions=0 if btlm_config.position_embedding_type == "alibi" else btlm_config.n_positions,
|
| 82 |
+
n_embd=btlm_config.hidden_size,
|
| 83 |
+
n_layer=btlm_config.num_hidden_layers,
|
| 84 |
+
n_head=btlm_config.num_attention_heads,
|
| 85 |
+
n_inner=btlm_config.n_inner,
|
| 86 |
+
activation_function=btlm_config.activation_function,
|
| 87 |
+
resid_pdrop=btlm_config.resid_pdrop,
|
| 88 |
+
embd_pdrop=btlm_config.embd_pdrop,
|
| 89 |
+
attn_pdrop=btlm_config.attn_pdrop,
|
| 90 |
+
layer_norm_epsilon=btlm_config.layer_norm_epsilon,
|
| 91 |
+
initializer_range=btlm_config.initializer_range,
|
| 92 |
+
bos_token_id=btlm_config.bos_token_id,
|
| 93 |
+
eos_token_id=btlm_config.eos_token_id,
|
| 94 |
+
# These are new arguments not in the original GPT2Config
|
| 95 |
+
use_alibi=btlm_config.position_embedding_type == "alibi",
|
| 96 |
+
use_flash_attn=btlm_config.position_embedding_type == "alibi", # Alibi code path requires flash_attn
|
| 97 |
+
mup_width_scale=btlm_config.mup_width_scale,
|
| 98 |
+
mup_embeddings_multiplier=btlm_config.mup_embeddings_scale,
|
| 99 |
+
mup_output_multiplier=btlm_config.mup_output_alpha,
|
| 100 |
+
mup_scale_qk_dot_by_d=btlm_config.mup_scale_qk_dot_by_d,
|
| 101 |
+
mlp_multiple_of=1,
|
| 102 |
+
)
|
.venv/lib/python3.11/site-packages/xformers/_flash_attn/models/falcon.py
ADDED
|
@@ -0,0 +1,143 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2023, Tri Dao.
|
| 2 |
+
|
| 3 |
+
import math
|
| 4 |
+
import re
|
| 5 |
+
from collections import OrderedDict
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
import torch.nn.functional as F
|
| 9 |
+
from einops import rearrange
|
| 10 |
+
from transformers import FalconConfig, GPT2Config
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def remap_state_dict_hf_falcon(state_dict, config):
|
| 14 |
+
def key_mapping_layers(key):
|
| 15 |
+
return re.sub(r"^transformer.h.", "transformer.layers.", key)
|
| 16 |
+
|
| 17 |
+
state_dict = OrderedDict((key_mapping_layers(k), v) for k, v in state_dict.items())
|
| 18 |
+
# Word embedding
|
| 19 |
+
def key_mapping_emb(key):
|
| 20 |
+
return re.sub(
|
| 21 |
+
r"^transformer.word_embeddings.", "transformer.embeddings.word_embeddings.", key
|
| 22 |
+
)
|
| 23 |
+
|
| 24 |
+
state_dict = OrderedDict((key_mapping_emb(k), v) for k, v in state_dict.items())
|
| 25 |
+
word_embeddings = state_dict.pop("transformer.embeddings.word_embeddings.weight")
|
| 26 |
+
# It's possible that vocab_size is padded to be a multiple of 8, for example.
|
| 27 |
+
pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
|
| 28 |
+
vocab_size = math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple
|
| 29 |
+
state_dict["transformer.embeddings.word_embeddings.weight"] = F.pad(
|
| 30 |
+
word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0])
|
| 31 |
+
)
|
| 32 |
+
if getattr(config, "tie_word_embeddings"):
|
| 33 |
+
state_dict["lm_head.weight"] = state_dict["transformer.embeddings.word_embeddings.weight"]
|
| 34 |
+
else:
|
| 35 |
+
output_embeddings = state_dict.pop("lm_head.weight")
|
| 36 |
+
# It's possible that vocab_size is padded to be a multiple of 8, for example.
|
| 37 |
+
state_dict["lm_head.weight"] = F.pad(
|
| 38 |
+
output_embeddings, (0, 0, 0, vocab_size - output_embeddings.shape[0])
|
| 39 |
+
)
|
| 40 |
+
output_embeddings_bias = state_dict.pop("lm_head.bias")
|
| 41 |
+
state_dict["lm_head.bias"] = F.pad(
|
| 42 |
+
output_embeddings_bias, (0, vocab_size - output_embeddings_bias.shape[0])
|
| 43 |
+
)
|
| 44 |
+
|
| 45 |
+
# LayerNorm
|
| 46 |
+
def key_mapping_ln(key):
|
| 47 |
+
key = re.sub(
|
| 48 |
+
r"^transformer.layers.(\d+).input_layernorm.", r"transformer.layers.\1.norm1.", key
|
| 49 |
+
)
|
| 50 |
+
key = re.sub(
|
| 51 |
+
r"^transformer.layers.(\d+).post_attention_layernorm.",
|
| 52 |
+
r"transformer.layers.\1.norm2.",
|
| 53 |
+
key,
|
| 54 |
+
)
|
| 55 |
+
key = re.sub(r"^transformer.layers.(\d+).ln_attn.", r"transformer.layers.\1.norm1.", key)
|
| 56 |
+
key = re.sub(r"^transformer.layers.(\d+).ln_mlp.", r"transformer.layers.\1.norm2.", key)
|
| 57 |
+
return key
|
| 58 |
+
|
| 59 |
+
state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items())
|
| 60 |
+
|
| 61 |
+
# MLP
|
| 62 |
+
def key_mapping_mlp(key):
|
| 63 |
+
key = re.sub(
|
| 64 |
+
r"^transformer.layers.(\d+).mlp.dense_h_to_4h.", r"transformer.layers.\1.mlp.fc1.", key
|
| 65 |
+
)
|
| 66 |
+
key = re.sub(
|
| 67 |
+
r"^transformer.layers.(\d+).mlp.dense_4h_to_h.", r"transformer.layers.\1.mlp.fc2.", key
|
| 68 |
+
)
|
| 69 |
+
return key
|
| 70 |
+
|
| 71 |
+
state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items())
|
| 72 |
+
|
| 73 |
+
def key_mapping_attn(key):
|
| 74 |
+
key = re.sub(
|
| 75 |
+
r"^transformer.layers.(\d+).self_attention.query_key_value.",
|
| 76 |
+
r"transformer.layers.\1.mixer.Wqkv.",
|
| 77 |
+
key,
|
| 78 |
+
)
|
| 79 |
+
key = re.sub(
|
| 80 |
+
r"^transformer.layers.(\d+).self_attention.dense.",
|
| 81 |
+
r"transformer.layers.\1.mixer.out_proj.",
|
| 82 |
+
key,
|
| 83 |
+
)
|
| 84 |
+
return key
|
| 85 |
+
|
| 86 |
+
state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())
|
| 87 |
+
n_head = config.n_head
|
| 88 |
+
n_head_kv = getattr(config, "n_head_kv", 1)
|
| 89 |
+
headdim = config.hidden_size // n_head
|
| 90 |
+
for l in range(config.n_layer):
|
| 91 |
+
# The weights are stored in a different layout compared to our implementation
|
| 92 |
+
Wqkv = rearrange(
|
| 93 |
+
state_dict.pop(f"transformer.layers.{l}.mixer.Wqkv.weight"),
|
| 94 |
+
"(group ratio headdim) ... -> group ratio headdim ...",
|
| 95 |
+
ratio=n_head // n_head_kv + 2,
|
| 96 |
+
headdim=headdim,
|
| 97 |
+
)
|
| 98 |
+
Wq = rearrange(Wqkv[:, :-2], "group ratio headdim ... -> (group ratio headdim) ...")
|
| 99 |
+
Wk = rearrange(Wqkv[:, [-2]], "group ratio headdim ... -> (group ratio headdim) ...")
|
| 100 |
+
Wv = rearrange(Wqkv[:, [-1]], "group ratio headdim ... -> (group ratio headdim) ...")
|
| 101 |
+
state_dict[f"transformer.layers.{l}.mixer.Wqkv.weight"] = torch.cat([Wq, Wk, Wv], dim=0)
|
| 102 |
+
|
| 103 |
+
return state_dict
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
def falcon_config_to_gpt2_config(falcon_config: FalconConfig) -> GPT2Config:
|
| 107 |
+
# The 40b config uses "n_head_kv" instead of "num_kv_heads"
|
| 108 |
+
n_head_kv = getattr(
|
| 109 |
+
falcon_config,
|
| 110 |
+
"n_head_kv",
|
| 111 |
+
1 if getattr(falcon_config, "multi_query", False) else falcon_config.n_head,
|
| 112 |
+
)
|
| 113 |
+
# HACK: the 40b config has 2 LN per layer instead of 1, but that's not reflected in the config.
|
| 114 |
+
# So we have to infer it from the number of heads in the key/value block
|
| 115 |
+
parallel_block_tied_norm = n_head_kv == 1
|
| 116 |
+
return GPT2Config(
|
| 117 |
+
vocab_size=falcon_config.vocab_size,
|
| 118 |
+
n_positions=0, # No absolute position embedding
|
| 119 |
+
n_embd=falcon_config.hidden_size,
|
| 120 |
+
n_layer=falcon_config.n_layer,
|
| 121 |
+
n_head=falcon_config.n_head,
|
| 122 |
+
n_inner=falcon_config.hidden_size * 4,
|
| 123 |
+
activation_function="gelu",
|
| 124 |
+
resid_pdrop=falcon_config.hidden_dropout,
|
| 125 |
+
embd_pdrop=0.0, # There doesn't seem to be any embedding dropout
|
| 126 |
+
attn_pdrop=falcon_config.attention_dropout,
|
| 127 |
+
layer_norm_epsilon=falcon_config.layer_norm_epsilon,
|
| 128 |
+
initializer_range=falcon_config.initializer_range,
|
| 129 |
+
bos_token_id=falcon_config.bos_token_id,
|
| 130 |
+
eos_token_id=falcon_config.eos_token_id,
|
| 131 |
+
# These are new arguments not in the original GPT2Config
|
| 132 |
+
parallel_block=falcon_config.parallel_attn,
|
| 133 |
+
n_head_kv=n_head_kv,
|
| 134 |
+
parallel_block_tied_norm=parallel_block_tied_norm,
|
| 135 |
+
rotary_emb_fraction=1.0,
|
| 136 |
+
rotary_emb_interleaved=False,
|
| 137 |
+
tie_word_embeddings=True,
|
| 138 |
+
qkv_proj_bias=falcon_config.bias,
|
| 139 |
+
out_proj_bias=falcon_config.bias,
|
| 140 |
+
mlp_fc1_bias=falcon_config.bias,
|
| 141 |
+
mlp_fc2_bias=falcon_config.bias,
|
| 142 |
+
lm_head_bias=False,
|
| 143 |
+
)
|
.venv/lib/python3.11/site-packages/xformers/_flash_attn/models/gpt.py
ADDED
|
@@ -0,0 +1,1080 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2024, Tri Dao.
|
| 2 |
+
|
| 3 |
+
import logging
|
| 4 |
+
import math
|
| 5 |
+
import re
|
| 6 |
+
from collections import OrderedDict, namedtuple
|
| 7 |
+
from collections.abc import Sequence
|
| 8 |
+
from functools import partial
|
| 9 |
+
from typing import Dict, List
|
| 10 |
+
|
| 11 |
+
import torch
|
| 12 |
+
import torch.nn as nn
|
| 13 |
+
import torch.nn.functional as F
|
| 14 |
+
from einops import rearrange
|
| 15 |
+
from transformers import GPT2Config
|
| 16 |
+
|
| 17 |
+
from flash_attn.models.bigcode import remap_state_dict_hf_bigcode
|
| 18 |
+
from flash_attn.models.falcon import remap_state_dict_hf_falcon
|
| 19 |
+
from flash_attn.models.gpt_neox import remap_state_dict_hf_gpt_neox
|
| 20 |
+
from flash_attn.models.gptj import remap_state_dict_hf_gptj
|
| 21 |
+
from flash_attn.models.llama import remap_state_dict_hf_llama
|
| 22 |
+
from flash_attn.models.opt import remap_state_dict_hf_opt
|
| 23 |
+
from flash_attn.modules.block import Block, ParallelBlock
|
| 24 |
+
from flash_attn.modules.embedding import GPT2Embeddings, ParallelGPT2Embeddings
|
| 25 |
+
from flash_attn.modules.mha import MHA, ParallelMHA
|
| 26 |
+
from flash_attn.modules.mlp import (
|
| 27 |
+
FusedMLP,
|
| 28 |
+
GatedMlp,
|
| 29 |
+
Mlp,
|
| 30 |
+
ParallelFusedMLP,
|
| 31 |
+
ParallelGatedMlp,
|
| 32 |
+
ParallelMLP,
|
| 33 |
+
)
|
| 34 |
+
from flash_attn.ops.activations import sqrelu_fwd
|
| 35 |
+
from flash_attn.utils.distributed import (
|
| 36 |
+
all_gather,
|
| 37 |
+
all_gather_raw,
|
| 38 |
+
get_dim_for_local_rank,
|
| 39 |
+
sync_shared_params,
|
| 40 |
+
)
|
| 41 |
+
from flash_attn.utils.generation import GenerationMixin
|
| 42 |
+
from flash_attn.utils.pretrained import state_dict_from_pretrained
|
| 43 |
+
|
| 44 |
+
try:
|
| 45 |
+
from flash_attn.ops.fused_dense import ColumnParallelLinear
|
| 46 |
+
except ImportError:
|
| 47 |
+
ColumnParallelLinear = None
|
| 48 |
+
|
| 49 |
+
try:
|
| 50 |
+
from flash_attn.ops.triton.mlp import FusedDenseSqreluDense
|
| 51 |
+
except ImportError:
|
| 52 |
+
FusedDenseSqreluDense = None
|
| 53 |
+
|
| 54 |
+
try:
|
| 55 |
+
from flash_attn.ops.triton.layer_norm import layer_norm_fn, RMSNorm
|
| 56 |
+
except ImportError:
|
| 57 |
+
layer_norm_fn, RMSNorm = None, None
|
| 58 |
+
|
| 59 |
+
logger = logging.getLogger(__name__)
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
def create_mixer_cls(config, layer_idx=None, process_group=None, device=None, dtype=None):
|
| 63 |
+
factory_kwargs = {"device": device, "dtype": dtype}
|
| 64 |
+
head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
|
| 65 |
+
attn_scale_power = 0.5 if not getattr(config, "mup_scale_qk_dot_by_d", False) else 1.0
|
| 66 |
+
softmax_scale = 1.0 if not config.scale_attn_weights else (head_dim ** (-attn_scale_power))
|
| 67 |
+
softmax_scale *= getattr(config, "mup_attn_multiplier", 1.0)
|
| 68 |
+
if config.scale_attn_by_inverse_layer_idx:
|
| 69 |
+
assert layer_idx is not None
|
| 70 |
+
softmax_scale /= float(layer_idx + 1)
|
| 71 |
+
dwconv = getattr(config, "attn_dwconv", False)
|
| 72 |
+
if dwconv:
|
| 73 |
+
assert process_group is None, "TensorParallel MHA does not support dwconv yet"
|
| 74 |
+
qkv_proj_bias = getattr(config, "qkv_proj_bias", True)
|
| 75 |
+
out_proj_bias = getattr(config, "out_proj_bias", True)
|
| 76 |
+
rotary_emb_dim = int(getattr(config, "rotary_emb_fraction", 0.0) * head_dim)
|
| 77 |
+
rotary_emb_base = getattr(config, "rotary_emb_base", 10000.0)
|
| 78 |
+
rotary_emb_scale_base = getattr(config, "rotary_emb_scale_base", None)
|
| 79 |
+
rotary_emb_interleaved = getattr(config, "rotary_emb_interleaved", False)
|
| 80 |
+
use_alibi = getattr(config, "use_alibi", False)
|
| 81 |
+
window_size = getattr(config, "window_size", (-1, -1))
|
| 82 |
+
use_flash_attn = getattr(config, "use_flash_attn", False)
|
| 83 |
+
fused_bias_fc = getattr(config, "fused_bias_fc", False)
|
| 84 |
+
if not fused_bias_fc:
|
| 85 |
+
assert process_group is None, "TensorParallel MHA requires fused_bias_fc"
|
| 86 |
+
mha_cls = MHA if process_group is None else ParallelMHA
|
| 87 |
+
serial_kwargs = (
|
| 88 |
+
{"fused_bias_fc": fused_bias_fc, "dwconv": dwconv} if process_group is None else {}
|
| 89 |
+
)
|
| 90 |
+
parallel_kwargs = (
|
| 91 |
+
{
|
| 92 |
+
"process_group": process_group,
|
| 93 |
+
"sequence_parallel": getattr(config, "sequence_parallel", True),
|
| 94 |
+
}
|
| 95 |
+
if process_group is not None
|
| 96 |
+
else {}
|
| 97 |
+
)
|
| 98 |
+
num_heads_kv = getattr(config, "n_head_kv", None)
|
| 99 |
+
mixer_cls = partial(
|
| 100 |
+
mha_cls,
|
| 101 |
+
num_heads=config.num_attention_heads,
|
| 102 |
+
num_heads_kv=num_heads_kv,
|
| 103 |
+
qkv_proj_bias=qkv_proj_bias,
|
| 104 |
+
out_proj_bias=out_proj_bias,
|
| 105 |
+
dropout=config.attn_pdrop,
|
| 106 |
+
softmax_scale=softmax_scale,
|
| 107 |
+
causal=True,
|
| 108 |
+
layer_idx=layer_idx,
|
| 109 |
+
rotary_emb_dim=rotary_emb_dim,
|
| 110 |
+
rotary_emb_base=rotary_emb_base,
|
| 111 |
+
rotary_emb_scale_base=rotary_emb_scale_base,
|
| 112 |
+
rotary_emb_interleaved=rotary_emb_interleaved,
|
| 113 |
+
use_alibi=use_alibi,
|
| 114 |
+
window_size=window_size,
|
| 115 |
+
use_flash_attn=use_flash_attn,
|
| 116 |
+
**serial_kwargs,
|
| 117 |
+
**parallel_kwargs,
|
| 118 |
+
**factory_kwargs,
|
| 119 |
+
)
|
| 120 |
+
return mixer_cls
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
def create_mlp_cls(config, layer_idx=None, process_group=None, device=None, dtype=None):
    """Build a partially-applied MLP class for one transformer block.

    Chooses between the reference (Gated)Mlp, the fused-kernel FusedMLP /
    FusedDenseSqreluDense, and their tensor-parallel variants according to
    config flags, and returns a ``functools.partial`` that the Block will
    instantiate with the hidden size.

    Args:
        config: GPT2Config-like object. Optional fields are read via
            ``getattr`` with defaults so plain HF configs still work.
        layer_idx: index of the layer being built; required when
            ``mlp_checkpoint_lvl`` is a per-layer sequence.
        process_group: ``torch.distributed`` process group for tensor
            parallelism, or None for the single-device path.
        device, dtype: forwarded to the MLP constructor.

    Raises:
        ImportError: if ``fused_mlp`` is requested but FusedMLP is unavailable.
        RuntimeError: if no supported MLP variant matches the config.
    """
    factory_kwargs = {"device": device, "dtype": dtype}
    mlp_fc1_bias = getattr(config, "mlp_fc1_bias", True)
    mlp_fc2_bias = getattr(config, "mlp_fc2_bias", True)
    fused_mlp = getattr(config, "fused_mlp", False)
    if fused_mlp:
        # The fused kernels only implement these activations.
        assert config.activation_function in [
            "gelu_new",
            "gelu_fast",
            "gelu_approx",
            "gelu_pytorch_tanh",
            "relu",
            "sqrelu",
        ]
    fused_dense_sqrelu_dense = getattr(config, "fused_dense_sqrelu_dense", False)
    if fused_dense_sqrelu_dense:
        assert config.activation_function == "sqrelu", (
            "fused_dense_sqrelu_dense only " "supports approximate activation_function sqrelu"
        )
    # The two fused paths are mutually exclusive.
    assert not (fused_dense_sqrelu_dense and fused_mlp)
    if not fused_mlp and not fused_dense_sqrelu_dense:
        # Reference (non-fused) path.
        assert config.activation_function in [
            "gelu",
            "gelu_new",
            "gelu_fast",
            "gelu_approx",
            "gelu_pytorch_tanh",
            "relu",
            "sqrelu",
            "glu",
            "swiglu",
            "geglu",
        ]
        if config.activation_function in ["glu", "swiglu", "geglu"]:
            # Gated MLP: the gate nonlinearity depends on the variant.
            activation = (
                F.sigmoid
                if config.activation_function == "glu"
                else (F.silu if config.activation_function == "swiglu" else F.gelu)
            )
            mlp_cls = GatedMlp if process_group is None else ParallelGatedMlp
            parallel_kwargs = (
                {
                    "process_group": process_group,
                    "sequence_parallel": getattr(config, "sequence_parallel", True),
                }
                if process_group is not None
                else {}
            )
            mlp_multiple_of = getattr(config, "mlp_multiple_of", 128)
            mlp_cls = partial(
                mlp_cls,
                hidden_features=config.n_inner,
                activation=activation,
                bias1=mlp_fc1_bias,
                bias2=mlp_fc2_bias,
                multiple_of=mlp_multiple_of,
                **parallel_kwargs,
                **factory_kwargs,
            )
        else:
            if config.activation_function == "relu":
                activation = partial(F.relu, inplace=True)
            elif config.activation_function == "sqrelu":
                activation = sqrelu_fwd
            else:
                # All the "approximate gelu" spellings map to the tanh form.
                approximate = (
                    "tanh"
                    if config.activation_function
                    in ["gelu_new", "gelu_fast", "gelu_approx", "gelu_pytorch_tanh"]
                    else "none"
                )
                activation = partial(F.gelu, approximate=approximate)
            mlp_cls = Mlp if process_group is None else ParallelMLP
            parallel_kwargs = (
                {
                    "process_group": process_group,
                    "sequence_parallel": getattr(config, "sequence_parallel", True),
                }
                if process_group is not None
                else {}
            )
            mlp_cls = partial(
                mlp_cls,
                hidden_features=config.n_inner,
                activation=activation,
                bias1=mlp_fc1_bias,
                bias2=mlp_fc2_bias,
                **parallel_kwargs,
                **factory_kwargs,
            )
    else:
        mlp_checkpoint_lvl = getattr(config, "mlp_checkpoint_lvl", 0)
        # mlp_checkpoint_lvl could be a list, which contains the checkpoint_lvl for each layer
        if isinstance(mlp_checkpoint_lvl, Sequence):
            assert layer_idx is not None
            mlp_checkpoint_lvl = mlp_checkpoint_lvl[layer_idx]
        if fused_mlp:
            if FusedMLP is None:
                raise ImportError("fused_dense is not installed")
            # The fused kernel takes the activation as a string, not a callable.
            activation = (
                "gelu_approx"
                if config.activation_function
                in ["gelu_new", "gelu_fast", "gelu_approx", "gelu_pytorch_tanh"]
                else config.activation_function
            )
            mlp_cls = FusedMLP if process_group is None else ParallelFusedMLP
            parallel_kwargs = (
                {
                    "process_group": process_group,
                    "sequence_parallel": getattr(config, "sequence_parallel", True),
                }
                if process_group is not None
                else {}
            )
            mlp_cls = partial(
                mlp_cls,
                hidden_features=config.n_inner,
                activation=activation,
                checkpoint_lvl=mlp_checkpoint_lvl,
                bias1=mlp_fc1_bias,
                bias2=mlp_fc2_bias,
                **parallel_kwargs,
                **factory_kwargs,
            )
        elif fused_dense_sqrelu_dense:
            if process_group is not None:
                assert fused_mlp, "Tensor Parallel is not implemented for FusedDenseSqreluDense"
            assert FusedDenseSqreluDense is not None
            mlp_cls = partial(
                FusedDenseSqreluDense,
                hidden_features=config.n_inner,
                checkpoint_lvl=mlp_checkpoint_lvl,
                **factory_kwargs,
            )
        else:
            raise RuntimeError("MLP type not supported")
    return mlp_cls
|
| 260 |
+
|
| 261 |
+
|
| 262 |
+
def create_block(config, layer_idx=None, process_group=None, device=None, dtype=None):
    """Construct one transformer block (mixer + MLP + norms) from the config.

    Wires together the mixer class from ``create_mixer_cls``, the MLP class
    from ``create_mlp_cls``, and a LayerNorm/RMSNorm factory, then builds
    either a sequential ``Block`` or a ``ParallelBlock`` (GPT-J / GPT-NeoX
    style, attention and MLP applied to the same input).

    Args:
        config: GPT2Config-like object.
        layer_idx: index of this layer; layer 0 uses ``embd_pdrop`` for its
            first residual dropout, later layers use ``resid_pdrop``.
        process_group: tensor-parallel process group, or None.
        device, dtype: forwarded to submodule constructors.

    Returns:
        The constructed block, with ``layer_idx`` attached for inference
        caching.
    """
    factory_kwargs = {"device": device, "dtype": dtype}
    sequence_parallel = getattr(config, "sequence_parallel", True)
    mixer_cls = create_mixer_cls(config, layer_idx, process_group=process_group, **factory_kwargs)
    mlp_cls = create_mlp_cls(config, layer_idx, process_group=process_group, **factory_kwargs)
    use_rms_norm = getattr(config, "rms_norm", False)
    norm_cls = partial(
        nn.LayerNorm if not use_rms_norm else RMSNorm,
        eps=config.layer_norm_epsilon,
        **factory_kwargs,
    )
    # TD [2022-07-30]: Force residual in fp32, seems to make fp16 training more stable
    residual_in_fp32 = getattr(config, "residual_in_fp32", False)
    # Layer 0's first dropout plays the role of the embedding dropout
    # (see the reordering note in GPTModel.__init__).
    resid_dropout1 = config.resid_pdrop if layer_idx is None or layer_idx > 0 else config.embd_pdrop
    prenorm = getattr(config, "prenorm", True)
    parallel_block = getattr(config, "parallel_block", False)
    if not parallel_block:
        block = Block(
            config.hidden_size,
            mixer_cls,
            mlp_cls,
            norm_cls=norm_cls,
            prenorm=prenorm,
            resid_dropout1=resid_dropout1,
            resid_dropout2=config.resid_pdrop,
            fused_dropout_add_ln=getattr(config, "fused_dropout_add_ln", False),
            residual_in_fp32=residual_in_fp32,
            sequence_parallel=sequence_parallel and process_group is not None,
            mark_shared_params=process_group is not None,
        )
    else:
        # ParallelBlock only supports the prenorm layout.
        assert prenorm
        block = ParallelBlock(
            config.hidden_size,
            mixer_cls,
            mlp_cls,
            norm_cls=norm_cls,
            resid_dropout1=resid_dropout1,
            resid_dropout2=config.resid_pdrop,
            tied_norm=getattr(config, "parallel_block_tied_norm", False),
            fused_dropout_add_ln=getattr(config, "fused_dropout_add_ln", False),
            residual_in_fp32=residual_in_fp32,
            sequence_parallel=sequence_parallel and process_group is not None,
            mark_shared_params=process_group is not None,
        )
    block.layer_idx = layer_idx
    return block
|
| 309 |
+
|
| 310 |
+
|
| 311 |
+
class GPTPreTrainedModel(nn.Module):
    """An abstract class to handle weights initialization and
    a simple interface for downloading and loading pretrained models.
    """

    def __init__(self, config, *inputs, **kwargs):
        super().__init__()
        # Subclasses rely on a GPT2Config; reject anything else early.
        if not isinstance(config, GPT2Config):
            raise ValueError(
                "Parameter config in `{}(config)` should be an instance of class `GPT2Config`. "
                "To create a model from a Google pretrained model use "
                "`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format(
                    self.__class__.__name__, self.__class__.__name__
                )
            )
        self.config = config

    @classmethod
    def from_pretrained(
        cls,
        model_name,
        config,
        *args,
        strict=True,
        device=None,
        dtype=None,
        world_size=1,
        rank=0,
        **kwargs,
    ):
        """
        Instantiate a GPTPreTrainedModel from a pre-trained model file or a pytorch state dict.
        Download and cache the pre-trained model file if needed.

        ``world_size`` / ``rank`` enable tensor parallelism: when
        ``world_size > 1`` the loaded state dict is sharded for this rank
        before loading.
        """
        # Instantiate model.
        model = cls(config, *args, device=device, dtype=dtype, **kwargs)
        # Load state_dict in cpu because we already initialized the model in GPU, and we don't
        # want extra stuff taking up more GPU memory
        state_dict = state_dict_from_pretrained(model_name, device="cpu", dtype=dtype)
        # Dispatch on the HF model-name prefix to remap that family's
        # checkpoint layout to this repo's parameter names.
        if model_name.startswith("gpt2"):
            state_dict = remap_state_dict_hf_gpt2(state_dict, config)
        elif model_name.startswith("facebook/opt"):
            state_dict = remap_state_dict_hf_opt(state_dict, config)
        elif model_name.startswith("EleutherAI/gpt-j-") or model_name.startswith(
            "togethercomputer/GPT-JT-"
        ):
            state_dict = remap_state_dict_hf_gptj(state_dict, config)
        elif (
            model_name.startswith("EleutherAI/gpt-neox-")
            or model_name.startswith("EleutherAI/pythia-")
            or model_name.startswith("togethercomputer/RedPajama-INCITE-")
        ):
            state_dict = remap_state_dict_hf_gpt_neox(state_dict, config)
        elif model_name.startswith("tiiuae/falcon-"):
            state_dict = remap_state_dict_hf_falcon(state_dict, config)
        elif model_name.startswith("meta-llama/Llama-"):
            state_dict = remap_state_dict_hf_llama(state_dict, config)
        elif model_name.startswith("bigcode/") or model_name.startswith("WizardLM/"):
            state_dict = remap_state_dict_hf_bigcode(state_dict, config)
        else:
            raise NotImplementedError(f"Model {model_name} not supported")
        if world_size > 1:
            # Keep only this rank's shard of each tensor-parallel weight.
            state_dict = shard_state_dict_tp(state_dict, config, world_size, rank)
        load_return = model.load_state_dict(state_dict, strict=strict)
        logger.info(load_return)
        return model
|
| 377 |
+
|
| 378 |
+
|
| 379 |
+
# https://github.com/huggingface/transformers/blob/c28d04e9e252a1a099944e325685f14d242ecdcd/src/transformers/models/gpt2/modeling_gpt2.py#L454
|
| 380 |
+
def _init_weights(
|
| 381 |
+
module, n_layer, initializer_range=0.02, mup_width_scale=1.0, rescale_prenorm_residual=True
|
| 382 |
+
):
|
| 383 |
+
mup_init_scale = math.sqrt(mup_width_scale)
|
| 384 |
+
if isinstance(module, nn.Linear):
|
| 385 |
+
nn.init.normal_(module.weight, std=initializer_range * mup_init_scale)
|
| 386 |
+
optim_cfg = getattr(module.weight, "_optim", {})
|
| 387 |
+
optim_cfg.update({"lr_multiplier": mup_width_scale})
|
| 388 |
+
setattr(module.weight, "_optim", optim_cfg)
|
| 389 |
+
if module.bias is not None:
|
| 390 |
+
nn.init.zeros_(module.bias)
|
| 391 |
+
elif isinstance(module, nn.Embedding):
|
| 392 |
+
nn.init.normal_(module.weight, std=initializer_range)
|
| 393 |
+
|
| 394 |
+
if rescale_prenorm_residual:
|
| 395 |
+
# Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
|
| 396 |
+
# > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
|
| 397 |
+
# > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
|
| 398 |
+
# > -- GPT-2 :: https://openai.com/blog/better-language-models/
|
| 399 |
+
#
|
| 400 |
+
# Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
|
| 401 |
+
for name, p in module.named_parameters():
|
| 402 |
+
if name in ["out_proj.weight", "fc2.weight"]:
|
| 403 |
+
# Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
|
| 404 |
+
nn.init.normal_(
|
| 405 |
+
p, mean=0.0, std=initializer_range * mup_init_scale / math.sqrt(2 * n_layer)
|
| 406 |
+
)
|
| 407 |
+
|
| 408 |
+
|
| 409 |
+
class GPTModel(GPTPreTrainedModel):
    """GPT transformer backbone: embeddings + stacked blocks + final norm.

    Supports tensor parallelism (via ``process_group``), sequence parallelism,
    parallel (GPT-J style) blocks, rotary embeddings, and a fused
    dropout + add + layer-norm epilogue.
    """

    def __init__(self, config: GPT2Config, process_group=None, device=None, dtype=None):
        super().__init__(config)
        factory_kwargs = {"device": device, "dtype": dtype}
        self.process_group = process_group
        self.sequence_parallel = getattr(config, "sequence_parallel", True)
        assert config.activation_function in [
            "gelu",
            "gelu_new",
            "gelu_fast",
            "gelu_approx",
            "gelu_pytorch_tanh",
            "relu",
            "sqrelu",
            "glu",
            "swiglu",
            "geglu",
        ]
        # Pad the vocab so it divides evenly (e.g. for tensor parallelism / GEMM efficiency).
        pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
        vocab_size = (
            math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple
        )
        # muP: scale applied to the embedding output in forward().
        self.embeddings_multiplier = getattr(config, "mup_embeddings_multiplier", 1.0)
        # TD [2022-07-30]: Force residual in fp32, seems to make fp16 training more stable
        self.residual_in_fp32 = getattr(config, "residual_in_fp32", False)
        # These 2 options are for OPT-350m
        self.prenorm = getattr(config, "prenorm", True)
        use_rms_norm = getattr(config, "rms_norm", False)
        word_embed_proj_dim = getattr(config, "word_embed_proj_dim", None)
        # For GPT-J, GPT-NeoX
        self.parallel_block = getattr(config, "parallel_block", False)

        if process_group is None:
            self.embeddings = GPT2Embeddings(
                config.hidden_size,
                vocab_size,
                config.max_position_embeddings,
                word_embed_proj_dim=word_embed_proj_dim,
                **factory_kwargs,
            )
        else:
            self.embeddings = ParallelGPT2Embeddings(
                config.hidden_size,
                vocab_size,
                config.max_position_embeddings,
                process_group=process_group,
                sequence_parallel=self.sequence_parallel,
                **factory_kwargs,
            )

        # We change the order of dropout, residual and layer norm:
        # Instead of LN -> Attn / MLP -> Dropout -> Add, we do:
        # Dropout -> Add -> LN -> Attn / MLP, returning both the residual branch (output of Add) and
        # the main branch (output of MLP). The model definition is unchanged, but the mapping of the
        # nn.Dropout probabilities are changed.
        # This is for performance reason: we can fuse dropout + add + layer_norm.
        self.layers = nn.ModuleList(
            [
                create_block(config, layer_idx=i, process_group=process_group, **factory_kwargs)
                for i in range(config.num_hidden_layers)
            ]
        )
        rotary_emb_fraction = getattr(config, "rotary_emb_fraction", 0.0)
        if rotary_emb_fraction > 0.0:  # Tie all the RotaryEmbedding modules to share the same cos/sin cache
            for layer in self.layers[1:]:
                layer.mixer.rotary_emb = self.layers[0].mixer.rotary_emb

        self.fused_dropout_add_ln = getattr(config, "fused_dropout_add_ln", False)
        if self.fused_dropout_add_ln:
            if layer_norm_fn is None:
                raise ImportError("Triton is not installed")
        if self.prenorm:
            # Final dropout + norm that close out the reordered residual stream.
            self.drop_f = nn.Dropout(config.resid_pdrop)
            norm_cls = nn.LayerNorm if not use_rms_norm else RMSNorm
            self.ln_f = norm_cls(
                config.hidden_size, eps=config.layer_norm_epsilon, **factory_kwargs
            )
        if process_group is not None:
            for p in self.ln_f.parameters():
                # Mark the norm parameters as "shared_params" so that we sync their values at init.
                p._shared_params = True
                # Mark the norm params as "sequence_parallel" so we run all-reduce on their grads.
                if self.sequence_parallel:
                    p._sequence_parallel = True

        self.apply(
            partial(
                _init_weights,
                n_layer=config.num_hidden_layers,
                initializer_range=config.initializer_range,
                mup_width_scale=getattr(config, "mup_width_scale", 1.0),
            )
        )
        self.tie_weights()

    def tie_weights(self):
        # With tensor parallelism, sync parameters marked as shared across ranks.
        if self.process_group is not None:
            sync_shared_params(self, self.process_group)

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
        # One KV-cache entry per layer, keyed by layer index.
        return {
            i: layer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)
            for i, layer in enumerate(self.layers)
        }

    def forward(self, input_ids, position_ids=None, inference_params=None):
        """Run the backbone; returns the final hidden states.

        input_ids: (batch, seqlen) int tensor.
        inference_params: KV-cache state for generation (optional).
        """
        # If using Tensor Parallel with sequence parallel, we combine the batch and the seqlen
        # dimensions so that we can split on it easily, in case of small batch size.
        # Only the attention layers need to know the seqlen.
        embedding_kwargs = (
            {"combine_batch_seqlen_dim": True}
            if self.process_group is not None and self.sequence_parallel
            else {}
        )
        hidden_states = self.embeddings(input_ids, position_ids=position_ids, **embedding_kwargs)
        if self.embeddings_multiplier != 1.0:
            # muP embedding scaling.
            hidden_states = hidden_states * self.embeddings_multiplier
        if self.parallel_block:
            # Second stream (MLP branch output) used by GPT-J style blocks.
            hidden_states2 = None
        residual = None
        mixer_kwargs = (
            {"seqlen": input_ids.shape[1]}
            if self.process_group is not None and self.sequence_parallel
            else {}
        )
        if inference_params is not None:
            mixer_kwargs["inference_params"] = inference_params
        for layer in self.layers:
            if self.prenorm:
                if not self.parallel_block:
                    hidden_states, residual = layer(
                        hidden_states, residual, mixer_kwargs=mixer_kwargs
                    )
                else:
                    hidden_states, hidden_states2, residual = layer(
                        hidden_states, hidden_states2, residual, mixer_kwargs=mixer_kwargs
                    )
            else:
                hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)
        if self.prenorm:
            if not self.fused_dropout_add_ln:
                # Unfused epilogue: dropout, add residual(s), then final norm.
                dropped = self.drop_f(hidden_states)
                if not self.parallel_block:
                    residual = (dropped + residual) if residual is not None else dropped
                else:
                    dropped2 = self.drop_f(hidden_states2)
                    residual = (
                        (residual + dropped + dropped2)
                        if residual is not None
                        else dropped + dropped2
                    )
                hidden_states = self.ln_f(residual.to(dtype=self.ln_f.weight.dtype))
            else:
                # Set prenorm=False here since we don't need the residual
                hidden_states = layer_norm_fn(
                    hidden_states,
                    self.ln_f.weight,
                    self.ln_f.bias,
                    residual=residual,
                    x1=None if not self.parallel_block else hidden_states2,
                    eps=self.ln_f.eps,
                    dropout_p=self.drop_f.p if self.training else 0.0,
                    prenorm=False,
                    is_rms_norm=isinstance(self.ln_f, RMSNorm)
                )
        return hidden_states
|
| 575 |
+
|
| 576 |
+
|
| 577 |
+
class GPTLMHeadModel(GPTPreTrainedModel, GenerationMixin):
    """GPT backbone plus a (possibly tensor-parallel) language-modeling head."""

    def __init__(self, config: GPT2Config, process_group=None, device=None, dtype=None):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__(config)
        self.process_group = process_group
        self.transformer = GPTModel(config, process_group=process_group, **factory_kwargs)
        self.tie_word_embeddings = getattr(config, "tie_word_embeddings", True)
        lm_head_bias = getattr(config, "lm_head_bias", False)
        # Pad the vocab so it divides evenly (must match GPTModel's padding).
        pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
        vocab_size = (
            math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple
        )
        # This option is for OPT-350m
        word_embed_proj_dim = getattr(config, "word_embed_proj_dim", None)
        embed_dim = config.n_embd if word_embed_proj_dim is None else word_embed_proj_dim
        if word_embed_proj_dim is not None:
            # Project hidden states down to the (smaller) embedding dim before the LM head.
            self.project_out = nn.Linear(config.n_embd, embed_dim, bias=False, **factory_kwargs)
        else:
            self.project_out = None
        # muP output scaling applied to hidden states before the LM head.
        mup_width_scale = getattr(config, "mup_width_scale", 1.0)
        mup_output_multiplier = getattr(config, "mup_output_multiplier", 1.0)
        self.output_scale = mup_output_multiplier * mup_width_scale
        if process_group is None:
            self.lm_head = nn.Linear(embed_dim, vocab_size, bias=lm_head_bias, **factory_kwargs)
        else:
            if ColumnParallelLinear is None:
                raise ImportError("fused_dense_lib is not installed")
            self.lm_head = ColumnParallelLinear(
                embed_dim,
                vocab_size,
                process_group,
                bias=lm_head_bias,
                sequence_parallel=getattr(config, "sequence_parallel", True),
                **factory_kwargs,
            )
        # If set, L2-normalize the head weight before computing logits.
        self.norm_head = getattr(config, "norm_head", False)
        # Initialize weights and apply final processing
        self.apply(
            partial(
                _init_weights,
                n_layer=config.num_hidden_layers,
                initializer_range=config.initializer_range,
                mup_width_scale=mup_width_scale,
            )
        )
        self.tie_weights()

    def tie_weights(self):
        # Share the LM head weight with the input word embeddings (standard GPT tying).
        if self.tie_word_embeddings:
            self.lm_head.weight = self.transformer.embeddings.word_embeddings.weight
        if self.process_group is not None:
            sync_shared_params(self, self.process_group)

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
        return self.transformer.allocate_inference_cache(
            batch_size, max_seqlen, dtype=dtype, **kwargs
        )

    def forward(self, input_ids, position_ids=None, inference_params=None, num_last_tokens=0):
        """
        input_ids: (batch, seqlen) int tensor
        inference_params: for generation. Adapted from Megatron-LM (and Apex)
        https://github.com/NVIDIA/apex/blob/3ff1a10f72ec07067c4e44759442329804ac5162/apex/transformer/testing/standalone_transformer_lm.py#L470
        num_last_tokens: if > 0, only return the logits for the last n tokens
        """
        assert (
            input_ids.ndim == 2
        ), f"Expected `input_ids` to have shape [b, slen], but got shape {input_ids.shape}"
        b, slen = input_ids.shape
        hidden_states = self.transformer(
            input_ids, position_ids=position_ids, inference_params=inference_params
        )
        if inference_params is not None:
            assert hidden_states.ndim == 3, "sequence_parallel is not supported in generation mode"
        if num_last_tokens > 0:
            hidden_states = hidden_states[:, -num_last_tokens:]
        if self.project_out is not None:
            hidden_states = self.project_out(hidden_states)
        if self.output_scale != 1.0:
            hidden_states = hidden_states * self.output_scale
        if not self.norm_head:
            lm_logits = self.lm_head(hidden_states)
        else:
            # Normalized head: compute logits against the L2-normalized weight.
            lm_head_weight = F.normalize(self.lm_head.weight)
            if isinstance(self.lm_head, ColumnParallelLinear) and self.lm_head.sequence_parallel:
                hidden_states = all_gather(hidden_states, self.lm_head.process_group)
            lm_logits = F.linear(hidden_states, lm_head_weight, bias=self.lm_head.bias)
        # During inference, we want the full logit for sampling
        if isinstance(self.lm_head, ColumnParallelLinear) and inference_params is not None:
            lm_logits, _ = all_gather_raw(lm_logits, self.lm_head.process_group)
            lm_logits = rearrange(lm_logits, "(n b) ... d -> b ... (n d)", b=b)
        CausalLMOutput = namedtuple("CausalLMOutput", ["logits"])
        return CausalLMOutput(logits=lm_logits)

    def load_state_dict(self, state_dict, strict=True):
        # Remapping from our checkpoints that used a different ordering of layers in the block
        # Previous: Attn / MLP -> Dropout -> Add -> LN
        # Current: Dropout -> Add -> LN -> Attn / MLP
        # In the old layout each layer's norms come *after* its sublayers, so every
        # norm shifts down by one layer: layer l's norm2 -> layer l+1's norm1, the
        # last norm2 -> ln_f, and the old ln_0 -> layer 0's norm1.
        if "transformer.ln_0.weight" in state_dict:
            n_layers = len(self.transformer.layers)
            ln_weight = state_dict.pop(f"transformer.layers.{n_layers - 1}.norm2.weight")
            ln_bias = state_dict.pop(f"transformer.layers.{n_layers - 1}.norm2.bias")
            state_dict["transformer.ln_f.weight"] = ln_weight
            state_dict["transformer.ln_f.bias"] = ln_bias
            # Walk backwards so each pop happens before its slot is overwritten.
            for l in reversed(range(n_layers)):
                ln_weight = state_dict.pop(f"transformer.layers.{l}.norm1.weight")
                ln_bias = state_dict.pop(f"transformer.layers.{l}.norm1.bias")
                state_dict[f"transformer.layers.{l}.norm2.weight"] = ln_weight
                state_dict[f"transformer.layers.{l}.norm2.bias"] = ln_bias
                if l > 0:
                    ln_weight = state_dict.pop(f"transformer.layers.{l - 1}.norm2.weight")
                    ln_bias = state_dict.pop(f"transformer.layers.{l - 1}.norm2.bias")
                    state_dict[f"transformer.layers.{l}.norm1.weight"] = ln_weight
                    state_dict[f"transformer.layers.{l}.norm1.bias"] = ln_bias
            ln_weight = state_dict.pop("transformer.ln_0.weight")
            ln_bias = state_dict.pop("transformer.ln_0.bias")
            state_dict[f"transformer.layers.0.norm1.weight"] = ln_weight
            state_dict[f"transformer.layers.0.norm1.bias"] = ln_bias
        return super().load_state_dict(state_dict, strict=strict)
|
| 696 |
+
|
| 697 |
+
|
| 698 |
+
def shard_state_dict_tp(state_dict, config, world_size, rank):
    """Convert the state_dict of a standard GPT model to the state_dict of a GPT model
    with tensor parallel.

    This function modifies state_dict in place.

    Sharding scheme (per rank):
      - word embeddings / lm_head: split along the vocab (first) dim;
      - position embeddings and row-parallel weights (attn out_proj, mlp fc2):
        split along the last dim; their biases are kept only on rank 0;
      - Wqkv: split head-wise, keeping whole heads together (handles MQA/GQA
        where n_head_kv != n_head);
      - gated-MLP fc1: split each of the two stacked halves separately so the
        gate/value pairing survives the shard.
    """
    pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
    vocab_size = math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple
    assert vocab_size % world_size == 0
    assert config.hidden_size % world_size == 0
    inner_dim = config.n_inner if config.n_inner is not None else 4 * config.hidden_size
    assert inner_dim % world_size == 0

    n_head = config.n_head
    n_head_kv = getattr(config, "n_head_kv", n_head)

    embed_dim = config.hidden_size
    head_dim = embed_dim // n_head

    def shard_first_dim(state_dict, key):
        # Take this rank's contiguous slice along dim 0 (column-parallel weights).
        if key in state_dict:
            x = state_dict[key]
            dim = x.shape[0] // world_size
            state_dict[key] = x[rank * dim : (rank + 1) * dim]

    def shard_last_dim(state_dict, key, multiple_of=1):
        # Take this rank's slice along the last dim (row-parallel weights),
        # keeping each rank's width a multiple of `multiple_of` (e.g. head_dim).
        if key in state_dict:
            x = state_dict[key]
            dim_each_rank = [
                get_dim_for_local_rank(x.size(-1), world_size, local_rank, multiple_of)
                for local_rank in range(world_size)
            ]
            beg, end = tuple(sum(dim_each_rank[:pos]) for pos in (rank, rank + 1))
            state_dict[key] = x[..., beg:end]

    def shard_gatedmlp_fc1_dim(state_dict, key):
        # fc1 of a gated MLP stacks (gate, value) along dim 0; shard each half
        # separately so every rank keeps matching gate/value slices.
        if key in state_dict:
            x = state_dict[key]
            dim = x.shape[0] // world_size // 2
            state_dict[key] = rearrange(
                rearrange(x, "(two o) ... -> two o ...", two=2)[:, rank * dim : (rank + 1) * dim],
                "two o ... -> (two o) ...",
            )

    def shard_qkv_headdim(state_dict, key):
        # Shard the packed QKV projection by whole heads. Layout is
        # (3 * d) when n_head_kv == n_head, else (n_head + 2*n_head_kv) heads.
        if key in state_dict:
            n_head_each_rank = [
                get_dim_for_local_rank(n_head, world_size, local_rank)
                for local_rank in range(world_size)
            ]
            n_head_kv_each_rank = [
                get_dim_for_local_rank(n_head_kv, world_size, local_rank)
                for local_rank in range(world_size)
            ]

            beg_n_head = sum(n_head_each_rank[:rank])
            end_n_head = sum(n_head_each_rank[: rank + 1])

            beg_n_head_kv = sum(n_head_kv_each_rank[:rank])
            end_n_head_kv = sum(n_head_kv_each_rank[: rank + 1])

            if n_head_kv == n_head:
                # Standard MHA: same head slice for Q, K and V.
                x = rearrange(state_dict[key], "(three d) ... -> three d ...", three=3)
                state_dict[key] = rearrange(
                    x[:, beg_n_head * head_dim : end_n_head * head_dim],
                    "three d ... -> (three d) ...",
                )
            else:
                # MQA/GQA: Q heads and KV heads are sliced independently, then
                # re-packed as [Q-slice, K-slice, V-slice].
                x = rearrange(
                    state_dict[key],
                    "(nheadqkv headdim) ... -> nheadqkv headdim ...",
                    nheadqkv=n_head + 2 * n_head_kv,
                )
                state_dict[key] = rearrange(
                    torch.cat(
                        [
                            x[beg_n_head:end_n_head],
                            x[n_head + beg_n_head_kv : n_head + end_n_head_kv],
                            x[
                                n_head
                                + n_head_kv
                                + beg_n_head_kv : n_head
                                + n_head_kv
                                + end_n_head_kv
                            ],
                        ],
                        dim=0,
                    ),
                    "nheadqkv headdim ... -> (nheadqkv headdim) ...",
                )

    shard_first_dim(state_dict, "transformer.embeddings.word_embeddings.weight")
    if "lm_head.weight" in state_dict:
        shard_first_dim(state_dict, "lm_head.weight")
    if "transformer.embeddings.position_embeddings.weight" in state_dict:
        shard_last_dim(state_dict, "transformer.embeddings.position_embeddings.weight")
    for i in range(config.num_hidden_layers):
        shard_qkv_headdim(state_dict, f"transformer.layers.{i}.mixer.Wqkv.weight")
        shard_qkv_headdim(state_dict, f"transformer.layers.{i}.mixer.Wqkv.bias")
        shard_last_dim(
            state_dict, f"transformer.layers.{i}.mixer.out_proj.weight", multiple_of=head_dim
        )
        if rank != 0:
            # Row-parallel biases are only added once, on rank 0.
            state_dict.pop(f"transformer.layers.{i}.mixer.out_proj.bias", None)
        if config.activation_function in ["glu", "swiglu", "geglu"]:
            shard_gatedmlp_fc1_dim(state_dict, f"transformer.layers.{i}.mlp.fc1.weight")
            shard_gatedmlp_fc1_dim(state_dict, f"transformer.layers.{i}.mlp.fc1.bias")
        else:
            shard_first_dim(state_dict, f"transformer.layers.{i}.mlp.fc1.weight")
            shard_first_dim(state_dict, f"transformer.layers.{i}.mlp.fc1.bias")
        shard_last_dim(state_dict, f"transformer.layers.{i}.mlp.fc2.weight")
        if rank != 0:
            state_dict.pop(f"transformer.layers.{i}.mlp.fc2.bias", None)
    return state_dict
|
| 812 |
+
|
| 813 |
+
|
| 814 |
+
def combine_state_dicts_tp(state_dicts: List[Dict[str, torch.Tensor]], config: GPT2Config):
    """Convert the list of sharded state_dict of a GPT model with tensor parallel to
    the state_dict of a standard GPT model.

    This function is meant to be the "reverse" of shard_state_dict_tp.

    Precondition:
    - state_dicts should be ordered in the same way as the shards were created.
    """
    world_size = len(state_dicts)
    # NOTE(review): `keys` is currently unused.
    keys = state_dicts[0].keys()
    # Vocab may have been padded to a multiple (e.g. of 8) before sharding.
    pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
    vocab_size = math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple
    # All sharded dimensions must divide evenly across ranks.
    assert vocab_size % world_size == 0
    assert config.hidden_size % world_size == 0
    inner_dim = config.n_inner if config.n_inner is not None else 4 * config.hidden_size
    assert inner_dim % world_size == 0
    assert config.hidden_size % config.n_head == 0
    headdim = config.hidden_size // config.n_head

    # Sometimes the word embeddings are sharded on the 0th dim, sometimes on the 1st dim.
    # vocab_size // world_size coordinates are nonzero.
    def combine_word_embeddings(state_dicts, state_dict, key):
        # Detect the sharding axis from the first shard's 0th-dim size, then
        # concatenate all shards along that axis.
        dim = 0 if state_dicts[0][key].shape[0] == vocab_size // world_size else 1
        state_dict[key] = torch.cat([s[key] for s in state_dicts], dim=dim)

    def combine_dim(state_dicts, state_dict, key, dim=-1):
        # Concatenate shards of `key` along `dim`; silently skips absent keys
        # (e.g. biases that were dropped on non-zero ranks).
        if key in state_dict:
            state_dict[key] = torch.cat([s[key] for s in state_dicts], dim=dim)

    def combine_qkv_headdim(state_dicts, state_dict, key):
        # Recombine the packed QKV weight/bias, which was sharded per attention head.
        n_head = config.n_head
        n_head_kv = getattr(config, "n_head_kv", n_head)
        if key in state_dict:
            if n_head_kv == n_head:
                # Standard MHA: each shard is (3 * d_local, ...); unpack the
                # leading (3, d_local) factor and concatenate on the head dim.
                xs = [
                    rearrange(s[key], "(three d) ... -> three d ...", three=3) for s in state_dicts
                ]
                state_dict[key] = rearrange(torch.cat(xs, dim=1), "three d ... -> (three d) ...")
            else:
                # MQA / GQA: ranks can hold different numbers of Q and KV heads,
                # so compute the per-rank head counts first.
                n_head_each_rank = [
                    get_dim_for_local_rank(n_head, world_size, local_rank)
                    for local_rank in range(world_size)
                ]
                n_head_kv_each_rank = [
                    get_dim_for_local_rank(n_head_kv, world_size, local_rank)
                    for local_rank in range(world_size)
                ]
                # Unpack each shard to (q_heads + 2 * kv_heads, headdim, ...).
                xs = [
                    rearrange(
                        s[key],
                        "(nheadqkv headdim) ... -> nheadqkv headdim ...",
                        nheadqkv=rank_n_head + 2 * rank_n_head_kv,
                        headdim=headdim,
                    )
                    for s, rank_n_head, rank_n_head_kv in zip(
                        state_dicts, n_head_each_rank, n_head_kv_each_rank
                    )
                ]
                # Slice each shard's Q / K / V head groups and concatenate per group.
                wq = torch.cat([x[: n_head_each_rank[rank]] for rank, x in enumerate(xs)], dim=0)
                wk = torch.cat(
                    [
                        x[
                            n_head_each_rank[rank] : n_head_each_rank[rank]
                            + n_head_kv_each_rank[rank]
                        ]
                        for rank, x in enumerate(xs)
                    ],
                    dim=0,
                )
                wv = torch.cat(
                    [
                        x[n_head_each_rank[rank] + n_head_kv_each_rank[rank] :]
                        for rank, x in enumerate(xs)
                    ],
                    dim=0,
                )
                wqkv = torch.cat(
                    [wq, wk, wv],
                    dim=0,
                )
                # Flatten back to the packed ((q + 2*kv heads) * headdim, ...) layout.
                state_dict[key] = rearrange(
                    wqkv,
                    "nheadqkv headdim ... -> (nheadqkv headdim) ...",
                )

    def combine_gated_mlp(state_dicts, state_dict, key):
        # Gated MLP fc1 is stored as two stacked halves (factor `two=2` below);
        # concatenate the shards within each half, then re-stack.
        if key in state_dict:
            xs = [rearrange(s[key], "(two d) ... -> two d ...", two=2) for s in state_dicts]
            state_dict[key] = rearrange(torch.cat(xs, dim=1), "two d ... -> (two d) ...")

    state_dict = state_dicts[0].copy()  # don't modify state_dict[0] inplace
    combine_word_embeddings(
        state_dicts, state_dict, "transformer.embeddings.word_embeddings.weight"
    )
    if "lm_head.weight" in state_dict:
        combine_word_embeddings(state_dicts, state_dict, "lm_head.weight")
    if "transformer.embeddings.position_embeddings.weight" in state_dict:
        combine_dim(
            state_dicts, state_dict, "transformer.embeddings.position_embeddings.weight", -1
        )
    # Gated activations need the half-aware combiner; plain MLPs concat on dim 0.
    mlp_combine_fn = (
        combine_gated_mlp
        if config.activation_function in ["glu", "swiglu", "geglu"]
        else partial(combine_dim, dim=0)
    )
    for i in range(config.num_hidden_layers):
        combine_qkv_headdim(state_dicts, state_dict, f"transformer.layers.{i}.mixer.Wqkv.weight")
        combine_qkv_headdim(state_dicts, state_dict, f"transformer.layers.{i}.mixer.Wqkv.bias")
        combine_dim(state_dicts, state_dict, f"transformer.layers.{i}.mixer.out_proj.weight", -1)
        mlp_combine_fn(state_dicts, state_dict, f"transformer.layers.{i}.mlp.fc1.weight")
        combine_dim(state_dicts, state_dict, f"transformer.layers.{i}.mlp.fc1.bias", 0)
        combine_dim(state_dicts, state_dict, f"transformer.layers.{i}.mlp.fc2.weight", -1)
    return state_dict
|
| 928 |
+
|
| 929 |
+
|
| 930 |
+
def remap_state_dict_hf_gpt2(state_dict, config):
    """Convert a Hugging Face GPT-2 checkpoint to this repo's GPT key layout.

    Renames ``wte`` / ``wpe`` / ``h.N.*`` keys to ``transformer.*``, pads the word
    embedding up to ``config.pad_vocab_size_multiple``, ties ``lm_head.weight`` to
    the word embedding, and transposes the weight matrices that HF stores in
    Conv1D (in, out) orientation so they fit ``nn.Linear``.

    Returns a new OrderedDict; tensors are reused (not copied) where possible.

    Note: all regex patterns escape ``.`` so that a dot in a checkpoint key can
    only match a literal dot (the previous unescaped patterns could in principle
    match unrelated keys).
    """

    # Word embedding and position embedding
    def key_mapping_pos_emb(key):
        return re.sub(r"^wpe\.", "transformer.embeddings.position_embeddings.", key)

    state_dict = OrderedDict((key_mapping_pos_emb(k), v) for k, v in state_dict.items())
    word_embeddings = state_dict.pop("wte.weight")
    # It's possible that vocab_size is padded to be a multiple of 8, for example.
    pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
    vocab_size = math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple
    state_dict["transformer.embeddings.word_embeddings.weight"] = F.pad(
        word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0])
    )
    # GPT-2 ties the output projection to the (padded) word embedding.
    state_dict["lm_head.weight"] = state_dict["transformer.embeddings.word_embeddings.weight"]

    # LayerNorm
    def key_mapping_ln(key):
        key = re.sub(r"^ln_f\.(weight|bias)", r"transformer.ln_f.\1", key)
        key = re.sub(
            r"^h\.(\d+)\.ln_(1|2)\.(weight|bias)", r"transformer.layers.\1.norm\2.\3", key
        )
        return key

    state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items())

    # MLP
    for d in range(config.num_hidden_layers):
        # HF stores these transposed relative to nn.Linear, hence .t().
        W1 = state_dict.pop(f"h.{d}.mlp.c_fc.weight")
        state_dict[f"transformer.layers.{d}.mlp.fc1.weight"] = W1.t()
        W2 = state_dict.pop(f"h.{d}.mlp.c_proj.weight")
        state_dict[f"transformer.layers.{d}.mlp.fc2.weight"] = W2.t()

    def key_mapping_mlp(key):
        key = re.sub(r"^h\.(\d+)\.mlp\.c_fc\.bias", r"transformer.layers.\1.mlp.fc1.bias", key)
        key = re.sub(r"^h\.(\d+)\.mlp\.c_proj\.bias", r"transformer.layers.\1.mlp.fc2.bias", key)
        return key

    state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items())

    # Attention
    for d in range(config.num_hidden_layers):
        state_dict.pop(f"h.{d}.attn.bias", None)  # We don't store this bias
        Wqkv = state_dict.pop(f"h.{d}.attn.c_attn.weight")
        state_dict[f"transformer.layers.{d}.mixer.Wqkv.weight"] = Wqkv.t()
        Wout = state_dict.pop(f"h.{d}.attn.c_proj.weight")
        state_dict[f"transformer.layers.{d}.mixer.out_proj.weight"] = Wout.t()

    def key_mapping_attn(key):
        key = re.sub(
            r"^h\.(\d+)\.attn\.c_attn\.bias", r"transformer.layers.\1.mixer.Wqkv.bias", key
        )
        key = re.sub(
            r"^h\.(\d+)\.attn\.c_proj\.bias", r"transformer.layers.\1.mixer.out_proj.bias", key
        )
        return key

    state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())

    return state_dict
|
| 985 |
+
|
| 986 |
+
|
| 987 |
+
def remap_state_dict_megatron(state_dict, config):
    """Convert a Megatron-LM GPT checkpoint's keys and QKV layout to this repo's
    GPT format.

    Returns a new OrderedDict; tensors are reused (not copied) where possible.
    """

    # Strip the Megatron "language_model(.encoder)" prefixes first so all later
    # patterns can match on "transformer.".
    def key_mapping_transformer(key):
        key = re.sub(r"^language_model.encoder.", "transformer.", key)
        key = re.sub(r"^language_model.", "transformer.", key)
        return key

    state_dict = OrderedDict((key_mapping_transformer(k), v) for k, v in state_dict.items())

    # Word embedding and position embedding
    def key_mapping_pos_emb(key):
        return re.sub(r"^wpe.", "transformer.embeddings.position_embeddings.", key)

    state_dict = OrderedDict((key_mapping_pos_emb(k), v) for k, v in state_dict.items())
    word_embeddings = state_dict.pop("transformer.embedding.word_embeddings.weight")
    # It's possible that vocab_size is padded to be a multiple of 8, for example.
    pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
    vocab_size = (
        math.ceil(word_embeddings.shape[0] / pad_vocab_size_multiple) * pad_vocab_size_multiple
    )
    state_dict["transformer.embeddings.word_embeddings.weight"] = F.pad(
        word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0])
    )
    # Tied output head: reuse the (padded) word embedding weight.
    state_dict["lm_head.weight"] = state_dict["transformer.embeddings.word_embeddings.weight"]

    # LayerNorm
    def key_mapping_ln(key):
        key = re.sub(r"^transformer.final_layernorm.(weight|bias)", r"transformer.ln_f.\1", key)
        key = re.sub(
            r"^transformer.layers.(\d+).input_layernorm.(weight|bias)",
            r"transformer.layers.\1.norm1.\2",
            key,
        )
        key = re.sub(
            r"^transformer.layers.(\d+).post_attention_layernorm.(weight|bias)",
            r"transformer.layers.\1.norm2.\2",
            key,
        )
        return key

    state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items())

    # MLP
    def key_mapping_mlp(key):
        key = re.sub(
            r"^transformer.layers.(\d+).mlp.dense_h_to_4h.(weight|bias)",
            r"transformer.layers.\1.mlp.fc1.\2",
            key,
        )
        key = re.sub(
            r"^transformer.layers.(\d+).mlp.dense_4h_to_h.(weight|bias)",
            r"transformer.layers.\1.mlp.fc2.\2",
            key,
        )
        return key

    state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items())

    # Attention
    def key_mapping_attn(key):
        key = re.sub(
            r"^transformer.layers.(\d+).self_attention.rotary_emb.inv_freq",
            r"transformer.layers.\1.mixer.rotary_emb.inv_freq",
            key,
        )
        key = re.sub(
            r"^transformer.layers.(\d+).self_attention.query_key_value.(weight|bias)",
            r"transformer.layers.\1.mixer.Wqkv.\2",
            key,
        )
        key = re.sub(
            r"^transformer.layers.(\d+).self_attention.dense.(weight|bias)",
            r"transformer.layers.\1.mixer.out_proj.\2",
            key,
        )
        return key

    state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())
    # Megatron stores Wqkv as ((nheads 3 headdim), hidden_dim)
    # while we store Wqkv as ((3 nheads headdim), hidden_dim)
    headdim = config.hidden_size // config.num_attention_heads
    for d in range(config.num_hidden_layers):
        # Permute both the weight and the bias from head-major to qkv-major layout.
        Wqkv = state_dict.pop(f"transformer.layers.{d}.mixer.Wqkv.weight")
        state_dict[f"transformer.layers.{d}.mixer.Wqkv.weight"] = rearrange(
            Wqkv,
            "(nheads three headdim) ... -> (three nheads headdim) ...",
            three=3,
            headdim=headdim,
        )
        bqkv = state_dict.pop(f"transformer.layers.{d}.mixer.Wqkv.bias")
        state_dict[f"transformer.layers.{d}.mixer.Wqkv.bias"] = rearrange(
            bqkv, "(nheads three headdim) -> (three nheads headdim)", three=3, headdim=headdim
        )

    return state_dict
|
.venv/lib/python3.11/site-packages/xformers/_flash_attn/models/gpt_neox.py
ADDED
|
@@ -0,0 +1,124 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2023, Tri Dao.
|
| 2 |
+
|
| 3 |
+
import math
|
| 4 |
+
import re
|
| 5 |
+
from collections import OrderedDict
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
import torch.nn.functional as F
|
| 9 |
+
from einops import rearrange
|
| 10 |
+
from transformers import GPT2Config, GPTNeoXConfig
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def remap_state_dict_hf_gpt_neox(state_dict, config):
    """Convert a Hugging Face GPT-NeoX checkpoint's keys and QKV layout to this
    repo's GPT format.

    Returns a new OrderedDict; tensors are reused (not copied) where possible.
    """

    def key_mapping_layers(key):
        return re.sub(r"^gpt_neox.", "transformer.", key)

    state_dict = OrderedDict((key_mapping_layers(k), v) for k, v in state_dict.items())

    # Word embedding
    def key_mapping_emb(key):
        return re.sub(r"^transformer.embed_in.", "transformer.embeddings.word_embeddings.", key)

    state_dict = OrderedDict((key_mapping_emb(k), v) for k, v in state_dict.items())
    word_embeddings = state_dict.pop("transformer.embeddings.word_embeddings.weight")
    # It's possible that vocab_size is padded to be a multiple of 8, for example.
    pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
    vocab_size = math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple
    state_dict["transformer.embeddings.word_embeddings.weight"] = F.pad(
        word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0])
    )
    if getattr(config, "tie_word_embeddings", False):
        # Tied head: reuse the (padded) word embedding weight.
        state_dict["lm_head.weight"] = state_dict["transformer.embeddings.word_embeddings.weight"]
    else:
        # Untied head: NeoX stores it under "embed_out"; pad it the same way.
        output_embeddings = state_dict.pop("embed_out.weight")
        # It's possible that vocab_size is padded to be a multiple of 8, for example.
        state_dict["lm_head.weight"] = F.pad(
            output_embeddings, (0, 0, 0, vocab_size - output_embeddings.shape[0])
        )

    # LayerNorm
    def key_mapping_ln(key):
        key = re.sub(r"^transformer.final_layer_norm.", r"transformer.ln_f.", key)
        key = re.sub(
            r"^transformer.layers.(\d+).input_layernorm.", r"transformer.layers.\1.norm1.", key
        )
        key = re.sub(
            r"^transformer.layers.(\d+).post_attention_layernorm.",
            r"transformer.layers.\1.norm2.",
            key,
        )
        return key

    state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items())

    # MLP
    def key_mapping_mlp(key):
        key = re.sub(
            r"^transformer.layers.(\d+).mlp.dense_h_to_4h.", r"transformer.layers.\1.mlp.fc1.", key
        )
        key = re.sub(
            r"^transformer.layers.(\d+).mlp.dense_4h_to_h.", r"transformer.layers.\1.mlp.fc2.", key
        )
        return key

    state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items())

    # Attention
    for l in range(config.n_layer):
        # We don't store these biases
        # NOTE(review): these two pops have no default, so a checkpoint missing
        # them would raise KeyError — confirm all target checkpoints include them.
        state_dict.pop(f"transformer.layers.{l}.attention.bias")
        state_dict.pop(f"transformer.layers.{l}.attention.masked_bias")
        # We don't store these
        state_dict.pop(f"transformer.layers.{l}.attention.rotary_emb.inv_freq", None)
        # GPT-NeoX stores Wqkv as ((nheads 3 headdim), hidden_dim)
        # while we store Wqkv as ((3 nheads headdim), hidden_dim)
        headdim = config.hidden_size // config.num_attention_heads
        Wqkv = state_dict.pop(f"transformer.layers.{l}.attention.query_key_value.weight")
        state_dict[f"transformer.layers.{l}.mixer.Wqkv.weight"] = rearrange(
            Wqkv,
            "(nheads three headdim) ... -> (three nheads headdim) ...",
            three=3,
            headdim=headdim,
        )
        bqkv = state_dict.pop(f"transformer.layers.{l}.attention.query_key_value.bias")
        state_dict[f"transformer.layers.{l}.mixer.Wqkv.bias"] = rearrange(
            bqkv, "(nheads three headdim) -> (three nheads headdim)", three=3, headdim=headdim
        )

    def key_mapping_attn(key):
        key = re.sub(
            r"^transformer.layers.(\d+).attention.dense.",
            r"transformer.layers.\1.mixer.out_proj.",
            key,
        )
        return key

    state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())

    return state_dict
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
def gpt_neox_config_to_gpt2_config(gpt_neox_config: GPTNeoXConfig) -> GPT2Config:
    """Translate a Hugging Face GPTNeoXConfig into the GPT2Config flavor used here."""
    # Only the default rotary base is handled by this conversion.
    assert gpt_neox_config.rotary_emb_base == 10000
    mapped = {
        "vocab_size": gpt_neox_config.vocab_size,
        "n_positions": 0,  # No absolute position embedding
        "n_embd": gpt_neox_config.hidden_size,
        "n_layer": gpt_neox_config.num_hidden_layers,
        "n_head": gpt_neox_config.num_attention_heads,
        "n_inner": gpt_neox_config.intermediate_size,
        "activation_function": gpt_neox_config.hidden_act,
        # No dropout
        "resid_pdrop": 0.0,
        "embd_pdrop": 0.0,
        "attn_pdrop": 0.0,
        "layer_norm_epsilon": gpt_neox_config.layer_norm_eps,
        "initializer_range": gpt_neox_config.initializer_range,
        "bos_token_id": gpt_neox_config.bos_token_id,
        "eos_token_id": gpt_neox_config.eos_token_id,
        # These are new arguments not in the original GPT2Config
        "prenorm": True,
        "parallel_block": gpt_neox_config.use_parallel_residual,
        "parallel_block_tied_norm": False,
        "rotary_emb_fraction": gpt_neox_config.rotary_pct,
        "tie_word_embeddings": gpt_neox_config.tie_word_embeddings,
    }
    return GPT2Config(**mapped)
|
.venv/lib/python3.11/site-packages/xformers/_flash_attn/models/gptj.py
ADDED
|
@@ -0,0 +1,109 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2023, Tri Dao.
|
| 2 |
+
|
| 3 |
+
import math
|
| 4 |
+
import re
|
| 5 |
+
from collections import OrderedDict
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
import torch.nn.functional as F
|
| 9 |
+
from transformers import GPT2Config, GPTJConfig
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def remap_state_dict_hf_gptj(state_dict, config):
    """Convert a Hugging Face GPT-J checkpoint to this repo's GPT key layout.

    Renames ``transformer.h.N.*`` keys to ``transformer.layers.N.*``, pads the
    word embedding (and untied lm_head) up to ``config.pad_vocab_size_multiple``,
    and packs the separate q/k/v projections into a single ``mixer.Wqkv``.

    Returns a new OrderedDict; tensors are reused (not copied) where possible.

    Notes on fixes: regex dots are escaped so they only match literal dots, and
    the non-parameter attention buffers are popped with a default so checkpoints
    that omit them no longer raise KeyError.
    """

    def key_mapping_layers(key):
        return re.sub(r"^transformer\.h\.", "transformer.layers.", key)

    state_dict = OrderedDict((key_mapping_layers(k), v) for k, v in state_dict.items())

    # Word embedding
    def key_mapping_emb(key):
        return re.sub(r"^transformer\.wte\.", "transformer.embeddings.word_embeddings.", key)

    state_dict = OrderedDict((key_mapping_emb(k), v) for k, v in state_dict.items())
    word_embeddings = state_dict.pop("transformer.embeddings.word_embeddings.weight")
    # It's possible that vocab_size is padded to be a multiple of 8, for example.
    pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
    vocab_size = math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple
    state_dict["transformer.embeddings.word_embeddings.weight"] = F.pad(
        word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0])
    )
    if getattr(config, "tie_word_embeddings"):
        # Tied head: reuse the (padded) word embedding weight.
        state_dict["lm_head.weight"] = state_dict["transformer.embeddings.word_embeddings.weight"]
    else:
        output_embeddings = state_dict.pop("lm_head.weight")
        # It's possible that vocab_size is padded to be a multiple of 8, for example.
        state_dict["lm_head.weight"] = F.pad(
            output_embeddings, (0, 0, 0, vocab_size - output_embeddings.shape[0])
        )
        output_embeddings_bias = state_dict.pop("lm_head.bias")
        state_dict["lm_head.bias"] = F.pad(
            output_embeddings_bias, (0, vocab_size - output_embeddings_bias.shape[0])
        )

    # LayerNorm (GPT-J has a single pre-norm per block)
    def key_mapping_ln(key):
        return re.sub(r"^transformer\.layers\.(\d+)\.ln_1\.", r"transformer.layers.\1.norm1.", key)

    state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items())

    # MLP
    def key_mapping_mlp(key):
        key = re.sub(
            r"^transformer\.layers\.(\d+)\.mlp\.fc_in\.", r"transformer.layers.\1.mlp.fc1.", key
        )
        key = re.sub(
            r"^transformer\.layers\.(\d+)\.mlp\.fc_out\.", r"transformer.layers.\1.mlp.fc2.", key
        )
        return key

    state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items())

    # Attention
    for l in range(config.n_layer):
        # Pack q, k, v (in that order) into a single projection.
        Wq = state_dict.pop(f"transformer.layers.{l}.attn.q_proj.weight")
        Wk = state_dict.pop(f"transformer.layers.{l}.attn.k_proj.weight")
        Wv = state_dict.pop(f"transformer.layers.{l}.attn.v_proj.weight")
        state_dict[f"transformer.layers.{l}.mixer.Wqkv.weight"] = torch.cat([Wq, Wk, Wv], dim=0)
        # We don't store these buffers; tolerate checkpoints that omit them.
        state_dict.pop(f"transformer.layers.{l}.attn.bias", None)
        state_dict.pop(f"transformer.layers.{l}.attn.masked_bias", None)

    def key_mapping_attn(key):
        return re.sub(
            r"^transformer\.layers\.(\d+)\.attn\.out_proj\.",
            r"transformer.layers.\1.mixer.out_proj.",
            key,
        )

    state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())

    return state_dict
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
def gptj_config_to_gpt2_config(gptj_config: GPTJConfig) -> GPT2Config:
    """Translate a Hugging Face GPTJConfig into the GPT2Config flavor used here."""
    head_dim = gptj_config.n_embd // gptj_config.n_head
    mapped = {
        "vocab_size": gptj_config.vocab_size,
        "n_positions": 0,  # No absolute position embedding
        "n_embd": gptj_config.n_embd,
        "n_layer": gptj_config.n_layer,
        "n_head": gptj_config.n_head,
        "n_inner": gptj_config.n_inner,
        "activation_function": gptj_config.activation_function,
        "resid_pdrop": gptj_config.resid_pdrop,
        "embd_pdrop": gptj_config.embd_pdrop,
        "attn_pdrop": gptj_config.attn_pdrop,
        "layer_norm_epsilon": gptj_config.layer_norm_epsilon,
        "initializer_range": gptj_config.initializer_range,
        "bos_token_id": gptj_config.bos_token_id,
        "eos_token_id": gptj_config.eos_token_id,
        # These are new arguments not in the original GPT2Config
        "prenorm": True,
        "parallel_block": True,
        "parallel_block_tied_norm": True,
        "rotary_emb_fraction": gptj_config.rotary_dim / head_dim,
        "rotary_emb_interleaved": True,
        "tie_word_embeddings": False,
        "qkv_proj_bias": False,
        "out_proj_bias": False,
        "lm_head_bias": True,
    }
    return GPT2Config(**mapped)
|
.venv/lib/python3.11/site-packages/xformers/_flash_attn/models/llama.py
ADDED
|
@@ -0,0 +1,422 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2023, Tri Dao.
|
| 2 |
+
|
| 3 |
+
import json
|
| 4 |
+
import math
|
| 5 |
+
import os
|
| 6 |
+
import re
|
| 7 |
+
from collections import OrderedDict
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
from typing import Dict, List, Union
|
| 10 |
+
|
| 11 |
+
import torch
|
| 12 |
+
import torch.nn.functional as F
|
| 13 |
+
from sentencepiece import SentencePieceProcessor
|
| 14 |
+
from transformers import GPT2Config, LlamaConfig
|
| 15 |
+
|
| 16 |
+
from einops import rearrange
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def remap_state_dict_meta_llama(
    state_dict: Dict[str, torch.Tensor], config: GPT2Config
) -> Dict[str, torch.Tensor]:
    """Convert the state_dict in Meta format to standard GPT format.

    This function modifies state_dict in place.
    """

    # Meta checkpoints have no "transformer." prefix except for "output.*" (the head).
    def key_mapping_layers(key):
        return f"transformer.{key}" if not key.startswith("output.") else key

    state_dict = OrderedDict((key_mapping_layers(k), v) for k, v in state_dict.items())

    # Word embedding
    def key_mapping_emb(key):
        return re.sub(
            r"^transformer.tok_embeddings.", "transformer.embeddings.word_embeddings.", key
        )

    state_dict = OrderedDict((key_mapping_emb(k), v) for k, v in state_dict.items())
    word_embeddings = state_dict.pop("transformer.embeddings.word_embeddings.weight")
    # It's possible that vocab_size is padded to be a multiple of 8, for example.
    pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
    vocab_size = (
        math.ceil(word_embeddings.shape[0] / pad_vocab_size_multiple) * pad_vocab_size_multiple
    )
    state_dict["transformer.embeddings.word_embeddings.weight"] = F.pad(
        word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0])
    )
    if getattr(config, "tie_word_embeddings"):
        # Tied head: reuse the (padded) word embedding weight.
        state_dict["lm_head.weight"] = state_dict["transformer.embeddings.word_embeddings.weight"]
    else:
        output_embeddings = state_dict.pop("output.weight")
        # Need to recompute vocab_size since LLaMa shards the word embeddings and output embeddings
        # differently.
        vocab_size = (
            math.ceil(output_embeddings.shape[0] / pad_vocab_size_multiple)
            * pad_vocab_size_multiple
        )
        # It's possible that vocab_size is padded to be a multiple of 8, for example.
        state_dict["lm_head.weight"] = F.pad(
            output_embeddings, (0, 0, 0, vocab_size - output_embeddings.shape[0])
        )

    # LayerNorm (RMSNorm keys in the Meta checkpoint)
    def key_mapping_ln(key):
        key = re.sub(r"^transformer.norm.", r"transformer.ln_f.", key)
        key = re.sub(
            r"^transformer.layers.(\d+).attention_norm.",
            r"transformer.layers.\1.norm1.",
            key,
        )
        key = re.sub(r"^transformer.layers.(\d+).ffn_norm.", r"transformer.layers.\1.norm2.", key)
        return key

    state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items())

    # MLP
    for l in range(config.n_layer):
        # Pack the two gated-MLP input projections into a single fc1.
        w1 = state_dict.pop(f"transformer.layers.{l}.feed_forward.w1.weight")
        w3 = state_dict.pop(f"transformer.layers.{l}.feed_forward.w3.weight")
        # Our ordering is different: w3 comes first, then w1.
        state_dict[f"transformer.layers.{l}.mlp.fc1.weight"] = torch.cat([w3, w1], dim=0)

    def key_mapping_mlp(key):
        return re.sub(
            r"^transformer.layers.(\d+).feed_forward.w2.",
            r"transformer.layers.\1.mlp.fc2.",
            key,
        )

    state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items())

    # Attention
    for l in range(config.n_layer):
        # Pack q, k, v (in that order) into a single Wqkv projection.
        Wq = state_dict.pop(f"transformer.layers.{l}.attention.wq.weight")
        Wk = state_dict.pop(f"transformer.layers.{l}.attention.wk.weight")
        Wv = state_dict.pop(f"transformer.layers.{l}.attention.wv.weight")
        state_dict[f"transformer.layers.{l}.mixer.Wqkv.weight"] = torch.cat([Wq, Wk, Wv], dim=0)
        # We don't store these
        state_dict.pop(f"transformer.layers.{l}.attention.inner_attention.rope.freqs", None)

    def key_mapping_attn(key):
        return re.sub(
            r"^transformer.layers.(\d+).attention.wo.",
            r"transformer.layers.\1.mixer.out_proj.",
            key,
        )

    state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())

    # Rotary frequencies are recomputed at load time, not stored.
    state_dict.pop("transformer.rope.freqs", None)

    return state_dict
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
def remap_state_dict_hf_llama(
|
| 116 |
+
state_dict: Dict[str, torch.Tensor], config: GPT2Config
|
| 117 |
+
) -> Dict[str, torch.Tensor]:
|
| 118 |
+
"""Convert the state_dict in Hugging Face format to standard GPT format.
|
| 119 |
+
|
| 120 |
+
This function modifies state_dict in place.
|
| 121 |
+
"""
|
| 122 |
+
|
| 123 |
+
# Embedding
|
| 124 |
+
def key_mapping_emb(key):
|
| 125 |
+
return re.sub(r"^model.embed_tokens.", "transformer.embeddings.word_embeddings.", key)
|
| 126 |
+
|
| 127 |
+
state_dict = OrderedDict((key_mapping_emb(k), v) for k, v in state_dict.items())
|
| 128 |
+
word_embeddings = state_dict.pop("transformer.embeddings.word_embeddings.weight")
|
| 129 |
+
# It's possible that vocab_size is padded to be a multiple of 8, for example.
|
| 130 |
+
pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
|
| 131 |
+
vocab_size = (
|
| 132 |
+
math.ceil(word_embeddings.shape[0] / pad_vocab_size_multiple) * pad_vocab_size_multiple
|
| 133 |
+
)
|
| 134 |
+
state_dict["transformer.embeddings.word_embeddings.weight"] = F.pad(
|
| 135 |
+
word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0])
|
| 136 |
+
)
|
| 137 |
+
|
| 138 |
+
# LM head
|
| 139 |
+
if getattr(config, "tie_word_embeddings"):
|
| 140 |
+
state_dict["lm_head.weight"] = state_dict["transformer.embeddings.word_embeddings.weight"]
|
| 141 |
+
else:
|
| 142 |
+
output_embeddings = state_dict.pop("lm_head.weight")
|
| 143 |
+
# Need to recompute vocab_size since LLaMa shards the word embeddings and output embeddings
|
| 144 |
+
# differently.
|
| 145 |
+
vocab_size = (
|
| 146 |
+
math.ceil(output_embeddings.shape[0] / pad_vocab_size_multiple)
|
| 147 |
+
* pad_vocab_size_multiple
|
| 148 |
+
)
|
| 149 |
+
# It's possible that vocab_size is padded to be a multiple of 8, for example.
|
| 150 |
+
state_dict["lm_head.weight"] = F.pad(
|
| 151 |
+
output_embeddings, (0, 0, 0, vocab_size - output_embeddings.shape[0])
|
| 152 |
+
)
|
| 153 |
+
|
| 154 |
+
# MLP
|
| 155 |
+
for l in range(config.n_layer):
|
| 156 |
+
# Fusing weights this way based on difference in the following:
|
| 157 |
+
# https://github.com/huggingface/transformers/blob/b42010bb1d3cbf262d27e0a328661885be46dfdb/src/transformers/models/llama/modeling_llama.py#L220
|
| 158 |
+
# https://github.com/Dao-AILab/flash-attention/blob/c60851a8253257eb970e06a022c82517a8033e8c/flash_attn/modules/mlp.py#L115
|
| 159 |
+
w1 = state_dict.pop(f"model.layers.{l}.mlp.gate_proj.weight")
|
| 160 |
+
w3 = state_dict.pop(f"model.layers.{l}.mlp.up_proj.weight")
|
| 161 |
+
state_dict[f"transformer.layers.{l}.mlp.fc1.weight"] = torch.cat([w3, w1], dim=0)
|
| 162 |
+
|
| 163 |
+
def key_mapping_mlp(key):
|
| 164 |
+
return re.sub(
|
| 165 |
+
r"^model.layers.(\d+).mlp.down_proj.",
|
| 166 |
+
r"transformer.layers.\1.mlp.fc2.",
|
| 167 |
+
key,
|
| 168 |
+
)
|
| 169 |
+
|
| 170 |
+
state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items())
|
| 171 |
+
|
| 172 |
+
# LayerNorm
|
| 173 |
+
def key_mapping_ln(key):
|
| 174 |
+
key = re.sub(r"^model.norm.", r"transformer.ln_f.", key)
|
| 175 |
+
key = re.sub(
|
| 176 |
+
r"^model.layers.(\d+).input_layernorm.",
|
| 177 |
+
r"transformer.layers.\1.norm1.",
|
| 178 |
+
key,
|
| 179 |
+
)
|
| 180 |
+
key = re.sub(
|
| 181 |
+
r"^model.layers.(\d+).post_attention_layernorm.",
|
| 182 |
+
r"transformer.layers.\1.norm2.",
|
| 183 |
+
key,
|
| 184 |
+
)
|
| 185 |
+
return key
|
| 186 |
+
|
| 187 |
+
state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items())
|
| 188 |
+
|
| 189 |
+
def inv_permute(w):
|
| 190 |
+
# Inverse of permute implemented in:
|
| 191 |
+
# https://github.com/huggingface/transformers/blob/b42010bb1d3cbf262d27e0a328661885be46dfdb/src/transformers/models/llama/convert_llama_weights_to_hf.py#L114
|
| 192 |
+
return rearrange(
|
| 193 |
+
w, "(h two d) n -> (h d two) n", d=config.n_embd // config.n_head // 2, two=2
|
| 194 |
+
)
|
| 195 |
+
|
| 196 |
+
# Attention
|
| 197 |
+
for l in range(config.n_layer):
|
| 198 |
+
Wq = state_dict.pop(f"model.layers.{l}.self_attn.q_proj.weight")
|
| 199 |
+
Wk = state_dict.pop(f"model.layers.{l}.self_attn.k_proj.weight")
|
| 200 |
+
Wv = state_dict.pop(f"model.layers.{l}.self_attn.v_proj.weight")
|
| 201 |
+
|
| 202 |
+
state_dict[f"transformer.layers.{l}.mixer.Wqkv.weight"] = torch.cat(
|
| 203 |
+
[inv_permute(Wq), inv_permute(Wk), Wv], dim=0
|
| 204 |
+
)
|
| 205 |
+
# We don't store these
|
| 206 |
+
state_dict.pop(f"model.layers.{l}.self_attn.rotary_emb.inv_freq", None)
|
| 207 |
+
|
| 208 |
+
def key_mapping_attn(key):
|
| 209 |
+
return re.sub(
|
| 210 |
+
r"^model.layers.(\d+).self_attn.o_proj.",
|
| 211 |
+
r"transformer.layers.\1.mixer.out_proj.",
|
| 212 |
+
key,
|
| 213 |
+
)
|
| 214 |
+
|
| 215 |
+
state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())
|
| 216 |
+
return state_dict
|
| 217 |
+
|
| 218 |
+
|
| 219 |
+
def inv_remap_state_dict_hf_llama(
|
| 220 |
+
state_dict: Dict[str, torch.Tensor], config: GPT2Config
|
| 221 |
+
) -> Dict[str, torch.Tensor]:
|
| 222 |
+
"""Convert the state_dict in standard GPT format to Hugging Face format.
|
| 223 |
+
|
| 224 |
+
This function is meant to be the inverse of remap_state_dict_hf_llama, up to a
|
| 225 |
+
multiplier pad in the embedding and lm_head. That is if the original embedding
|
| 226 |
+
isn't a multiple of pad_vocab_size_multiple, then
|
| 227 |
+
inv_remap_state_dict_hf_llama(remap_state_dict_hf_llama(state_dict)) != state_dict.
|
| 228 |
+
|
| 229 |
+
This function modifies state_dict in place.
|
| 230 |
+
"""
|
| 231 |
+
|
| 232 |
+
# Embedding
|
| 233 |
+
def key_mapping_emb(key):
|
| 234 |
+
return re.sub(r"^transformer.embeddings.word_embeddings.", "model.embed_tokens.", key)
|
| 235 |
+
|
| 236 |
+
state_dict = OrderedDict((key_mapping_emb(k), v) for k, v in state_dict.items())
|
| 237 |
+
word_embeddings = state_dict.pop("model.embed_tokens.weight")
|
| 238 |
+
pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
|
| 239 |
+
vocab_size = (
|
| 240 |
+
math.ceil(word_embeddings.shape[0] / pad_vocab_size_multiple) * pad_vocab_size_multiple
|
| 241 |
+
)
|
| 242 |
+
state_dict["model.embed_tokens.weight"] = F.pad(
|
| 243 |
+
word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0])
|
| 244 |
+
)
|
| 245 |
+
|
| 246 |
+
# LM head
|
| 247 |
+
if getattr(config, "tie_word_embeddings"):
|
| 248 |
+
state_dict["lm_head.weight"] = state_dict["model.embed_tokens.weight"]
|
| 249 |
+
else:
|
| 250 |
+
output_embeddings = state_dict.pop("lm_head.weight")
|
| 251 |
+
vocab_size = (
|
| 252 |
+
math.ceil(output_embeddings.shape[0] / pad_vocab_size_multiple)
|
| 253 |
+
* pad_vocab_size_multiple
|
| 254 |
+
)
|
| 255 |
+
state_dict["lm_head.weight"] = F.pad(
|
| 256 |
+
output_embeddings, (0, 0, 0, vocab_size - output_embeddings.shape[0])
|
| 257 |
+
)
|
| 258 |
+
|
| 259 |
+
# MLP
|
| 260 |
+
for l in range(config.n_layer):
|
| 261 |
+
w3, w1 = torch.chunk(
|
| 262 |
+
state_dict.pop(f"transformer.layers.{l}.mlp.fc1.weight"), chunks=2, dim=0
|
| 263 |
+
)
|
| 264 |
+
state_dict[f"model.layers.{l}.mlp.gate_proj.weight"] = w1
|
| 265 |
+
state_dict[f"model.layers.{l}.mlp.up_proj.weight"] = w3
|
| 266 |
+
|
| 267 |
+
def key_mapping_mlp(key):
|
| 268 |
+
return re.sub(
|
| 269 |
+
r"^transformer.layers.(\d+).mlp.fc2.",
|
| 270 |
+
r"model.layers.\1.mlp.down_proj.",
|
| 271 |
+
key,
|
| 272 |
+
)
|
| 273 |
+
|
| 274 |
+
state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items())
|
| 275 |
+
|
| 276 |
+
# LayerNorm
|
| 277 |
+
def key_mapping_ln(key):
|
| 278 |
+
key = re.sub(r"^transformer.ln_f.", r"model.norm.", key)
|
| 279 |
+
key = re.sub(
|
| 280 |
+
r"^transformer.layers.(\d+).norm1.",
|
| 281 |
+
r"model.layers.\1.input_layernorm.",
|
| 282 |
+
key,
|
| 283 |
+
)
|
| 284 |
+
key = re.sub(
|
| 285 |
+
r"^transformer.layers.(\d+).norm2.",
|
| 286 |
+
r"model.layers.\1.post_attention_layernorm.",
|
| 287 |
+
key,
|
| 288 |
+
)
|
| 289 |
+
return key
|
| 290 |
+
|
| 291 |
+
state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items())
|
| 292 |
+
|
| 293 |
+
def permute(w):
|
| 294 |
+
return rearrange(
|
| 295 |
+
w, "(h d two) n -> (h two d) n", d=config.n_embd // config.n_head // 2, two=2
|
| 296 |
+
)
|
| 297 |
+
|
| 298 |
+
n_head = config.n_head
|
| 299 |
+
n_head_kv = getattr(config, "n_head_kv", n_head)
|
| 300 |
+
|
| 301 |
+
embed_dim = config.hidden_size
|
| 302 |
+
head_dim = embed_dim // n_head
|
| 303 |
+
|
| 304 |
+
q_dim = n_head * head_dim
|
| 305 |
+
k_dim = v_dim = n_head_kv * head_dim
|
| 306 |
+
|
| 307 |
+
# Attention
|
| 308 |
+
for l in range(config.n_layer):
|
| 309 |
+
Wqkv = state_dict.pop(f"transformer.layers.{l}.mixer.Wqkv.weight")
|
| 310 |
+
Wq = Wqkv[:q_dim]
|
| 311 |
+
Wk = Wqkv[q_dim : q_dim + k_dim]
|
| 312 |
+
Wv = Wqkv[q_dim + k_dim : q_dim + k_dim + v_dim]
|
| 313 |
+
state_dict[f"model.layers.{l}.self_attn.q_proj.weight"] = permute(Wq)
|
| 314 |
+
state_dict[f"model.layers.{l}.self_attn.k_proj.weight"] = permute(Wk)
|
| 315 |
+
state_dict[f"model.layers.{l}.self_attn.v_proj.weight"] = Wv
|
| 316 |
+
state_dict.pop(f"transformer.layers.{l}.attention.inner_attention.rope.freqs", None)
|
| 317 |
+
|
| 318 |
+
def key_mapping_attn(key):
|
| 319 |
+
return re.sub(
|
| 320 |
+
r"^transformer.layers.(\d+).mixer.out_proj.",
|
| 321 |
+
r"model.layers.\1.self_attn.o_proj.",
|
| 322 |
+
key,
|
| 323 |
+
)
|
| 324 |
+
|
| 325 |
+
state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())
|
| 326 |
+
return state_dict
|
| 327 |
+
|
| 328 |
+
|
| 329 |
+
def config_from_meta_checkpoint(
|
| 330 |
+
checkpoint_path: Union[str, os.PathLike], model_name: str
|
| 331 |
+
) -> LlamaConfig:
|
| 332 |
+
"""Load a LlamaConfig from a checkpoint path."""
|
| 333 |
+
with open(Path(checkpoint_path) / model_name / "params.json") as f:
|
| 334 |
+
params = json.load(f)
|
| 335 |
+
config = LlamaConfig(
|
| 336 |
+
hidden_size=params["dim"],
|
| 337 |
+
intermediate_size=None,
|
| 338 |
+
num_attention_heads=params["n_heads"],
|
| 339 |
+
num_hidden_layers=params["n_layers"],
|
| 340 |
+
rms_norm_eps=params["norm_eps"],
|
| 341 |
+
num_key_value_heads=params.get("n_kv_heads", None),
|
| 342 |
+
)
|
| 343 |
+
multiple_of = params.get("multiple_of", 1)
|
| 344 |
+
ffn_dim_multiplier = params.get("ffn_dim_multiplier", None)
|
| 345 |
+
|
| 346 |
+
# Compute the hidden dimension of the MLP
|
| 347 |
+
# https://github.com/facebookresearch/llama/blob/1a240688810f8036049e8da36b073f63d2ac552c/llama/model.py#L224
|
| 348 |
+
intermediate_size = 4 * config.hidden_size
|
| 349 |
+
# https://github.com/facebookresearch/llama/blob/1a240688810f8036049e8da36b073f63d2ac552c/llama/model.py#L195-L199
|
| 350 |
+
intermediate_size = int(2 * intermediate_size / 3)
|
| 351 |
+
# custom dim factor multiplier
|
| 352 |
+
if ffn_dim_multiplier is not None:
|
| 353 |
+
intermediate_size = int(ffn_dim_multiplier * intermediate_size)
|
| 354 |
+
intermediate_size = multiple_of * ((intermediate_size + multiple_of - 1) // multiple_of)
|
| 355 |
+
|
| 356 |
+
config.intermediate_size = intermediate_size
|
| 357 |
+
if "rope_theta" in params:
|
| 358 |
+
config.rotary_emb_base = params["rope_theta"]
|
| 359 |
+
config.vocab_size = 32000
|
| 360 |
+
# some CodeLLaMa have vocab_size 32000, some 32016
|
| 361 |
+
# Sadly it's not specified in the `params.json` file :(
|
| 362 |
+
tokenizer = Path(checkpoint_path) / model_name / "tokenizer.model"
|
| 363 |
+
if tokenizer.is_file():
|
| 364 |
+
config.vocab_size = SentencePieceProcessor(str(tokenizer)).vocab_size()
|
| 365 |
+
return config
|
| 366 |
+
|
| 367 |
+
|
| 368 |
+
def config_from_hf_checkpoint(
|
| 369 |
+
checkpoint_path: Union[str, os.PathLike], model_name: str
|
| 370 |
+
) -> LlamaConfig:
|
| 371 |
+
return LlamaConfig.from_pretrained(Path(checkpoint_path) / f"{model_name}-hf" / "config.json")
|
| 372 |
+
|
| 373 |
+
|
| 374 |
+
def config_from_checkpoint(
|
| 375 |
+
checkpoint_path: Union[str, os.PathLike], model_name: str, checkpoint_format="meta"
|
| 376 |
+
) -> LlamaConfig:
|
| 377 |
+
if checkpoint_format == "meta":
|
| 378 |
+
return config_from_meta_checkpoint(checkpoint_path, model_name)
|
| 379 |
+
else:
|
| 380 |
+
return config_from_hf_checkpoint(checkpoint_path, model_name)
|
| 381 |
+
|
| 382 |
+
|
| 383 |
+
def state_dicts_from_checkpoint(
|
| 384 |
+
checkpoint_path: Union[str, os.PathLike], model_name: str
|
| 385 |
+
) -> List[dict]:
|
| 386 |
+
# Need to sort, otherwise we mess up the ordering and the weights are wrong
|
| 387 |
+
return [
|
| 388 |
+
torch.load(path, map_location="cpu")
|
| 389 |
+
for path in sorted((Path(checkpoint_path) / model_name).glob("consolidated.*.pth"))
|
| 390 |
+
]
|
| 391 |
+
|
| 392 |
+
|
| 393 |
+
def llama_config_to_gpt2_config(llama_config: LlamaConfig) -> GPT2Config:
|
| 394 |
+
return GPT2Config(
|
| 395 |
+
vocab_size=llama_config.vocab_size,
|
| 396 |
+
n_positions=0, # No absolute position embedding
|
| 397 |
+
n_embd=llama_config.hidden_size,
|
| 398 |
+
n_layer=llama_config.num_hidden_layers,
|
| 399 |
+
n_head=llama_config.num_attention_heads,
|
| 400 |
+
n_inner=llama_config.intermediate_size,
|
| 401 |
+
activation_function="swiglu", # Hardcode since HF calls it 'silu'
|
| 402 |
+
# Llama doesn't have dropout, idk if it's because they only release the inference code
|
| 403 |
+
resid_pdrop=0.0,
|
| 404 |
+
embd_pdrop=0.0,
|
| 405 |
+
attn_pdrop=0.0,
|
| 406 |
+
layer_norm_epsilon=llama_config.rms_norm_eps,
|
| 407 |
+
initializer_range=llama_config.initializer_range,
|
| 408 |
+
bos_token_id=llama_config.bos_token_id,
|
| 409 |
+
eos_token_id=llama_config.eos_token_id,
|
| 410 |
+
# These are new arguments not in the original GPT2Config
|
| 411 |
+
pad_token_id=llama_config.pad_token_id, # Idk if this does anything
|
| 412 |
+
rms_norm=True,
|
| 413 |
+
rotary_emb_fraction=1.0,
|
| 414 |
+
rotary_emb_interleaved=True,
|
| 415 |
+
tie_word_embeddings=False,
|
| 416 |
+
qkv_proj_bias=False,
|
| 417 |
+
out_proj_bias=False,
|
| 418 |
+
mlp_fc1_bias=False,
|
| 419 |
+
mlp_fc2_bias=False,
|
| 420 |
+
rotary_emb_base=getattr(llama_config, "rotary_emb_base", 10000.0),
|
| 421 |
+
n_head_kv=llama_config.num_key_value_heads,
|
| 422 |
+
)
|
.venv/lib/python3.11/site-packages/xformers/_flash_attn/models/opt.py
ADDED
|
@@ -0,0 +1,116 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2023, Tri Dao.
|
| 2 |
+
|
| 3 |
+
import math
|
| 4 |
+
import re
|
| 5 |
+
from collections import OrderedDict
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
import torch.nn.functional as F
|
| 9 |
+
from transformers import GPT2Config, OPTConfig
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def remap_state_dict_hf_opt(state_dict, config):
|
| 13 |
+
def key_mapping_model(key):
|
| 14 |
+
key = re.sub(r"^model.decoder.", "transformer.", key)
|
| 15 |
+
# The OPT-350m model uses '^decoder' instead of '^model.decoder'
|
| 16 |
+
key = re.sub(r"^decoder.", "transformer.", key)
|
| 17 |
+
return key
|
| 18 |
+
|
| 19 |
+
state_dict = OrderedDict((key_mapping_model(k), v) for k, v in state_dict.items())
|
| 20 |
+
# Word embedding and position embedding
|
| 21 |
+
def key_mapping_emb(key):
|
| 22 |
+
key = re.sub(r"^transformer.embed_tokens.", "transformer.embeddings.word_embeddings.", key)
|
| 23 |
+
# The OPT-350m model uses has project_in and project_out
|
| 24 |
+
key = re.sub(r"^transformer.project_in.", "transformer.embeddings.project_in.", key)
|
| 25 |
+
key = re.sub(r"^transformer.project_out.", "project_out.", key)
|
| 26 |
+
key = re.sub(
|
| 27 |
+
r"^transformer.embed_positions.", "transformer.embeddings.position_embeddings.", key
|
| 28 |
+
)
|
| 29 |
+
return key
|
| 30 |
+
|
| 31 |
+
state_dict = OrderedDict((key_mapping_emb(k), v) for k, v in state_dict.items())
|
| 32 |
+
# OPT uses the first 2 indices of pos_emb for padding tokens
|
| 33 |
+
pos_embeddings = state_dict.pop("transformer.embeddings.position_embeddings.weight")
|
| 34 |
+
state_dict["transformer.embeddings.position_embeddings.weight"] = pos_embeddings[2:]
|
| 35 |
+
word_embeddings = state_dict.pop("transformer.embeddings.word_embeddings.weight")
|
| 36 |
+
# It's possible that vocab_size is padded to be a multiple of 8, for example.
|
| 37 |
+
pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
|
| 38 |
+
vocab_size = math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple
|
| 39 |
+
state_dict["transformer.embeddings.word_embeddings.weight"] = F.pad(
|
| 40 |
+
word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0])
|
| 41 |
+
)
|
| 42 |
+
state_dict["lm_head.weight"] = state_dict["transformer.embeddings.word_embeddings.weight"]
|
| 43 |
+
|
| 44 |
+
# LayerNorm
|
| 45 |
+
def key_mapping_ln(key):
|
| 46 |
+
key = re.sub(r"^transformer.final_layer_norm.", r"transformer.ln_f.", key)
|
| 47 |
+
# The OPT-175B checkpoint calls this 'decoder.layer_norm' instead of 'decoder.final_layer_norm'
|
| 48 |
+
key = re.sub(r"^transformer.layer_norm.", r"transformer.ln_f.", key)
|
| 49 |
+
key = re.sub(
|
| 50 |
+
r"^transformer.layers.(\d+).self_attn_layer_norm.", r"transformer.layers.\1.norm1.", key
|
| 51 |
+
)
|
| 52 |
+
key = re.sub(
|
| 53 |
+
r"^transformer.layers.(\d+).final_layer_norm.", r"transformer.layers.\1.norm2.", key
|
| 54 |
+
)
|
| 55 |
+
return key
|
| 56 |
+
|
| 57 |
+
state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items())
|
| 58 |
+
|
| 59 |
+
# MLP
|
| 60 |
+
def key_mapping_mlp(key):
|
| 61 |
+
return re.sub(
|
| 62 |
+
r"^transformer.layers.(\d+).fc(1|2).", r"transformer.layers.\1.mlp.fc\2.", key
|
| 63 |
+
)
|
| 64 |
+
|
| 65 |
+
state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items())
|
| 66 |
+
|
| 67 |
+
# Attention
|
| 68 |
+
for l in range(config.n_layer):
|
| 69 |
+
Wq = state_dict.pop(f"transformer.layers.{l}.self_attn.q_proj.weight")
|
| 70 |
+
Wk = state_dict.pop(f"transformer.layers.{l}.self_attn.k_proj.weight")
|
| 71 |
+
Wv = state_dict.pop(f"transformer.layers.{l}.self_attn.v_proj.weight")
|
| 72 |
+
bq = state_dict.pop(f"transformer.layers.{l}.self_attn.q_proj.bias")
|
| 73 |
+
bk = state_dict.pop(f"transformer.layers.{l}.self_attn.k_proj.bias")
|
| 74 |
+
bv = state_dict.pop(f"transformer.layers.{l}.self_attn.v_proj.bias")
|
| 75 |
+
state_dict[f"transformer.layers.{l}.mixer.Wqkv.weight"] = torch.cat([Wq, Wk, Wv], dim=0)
|
| 76 |
+
state_dict[f"transformer.layers.{l}.mixer.Wqkv.bias"] = torch.cat([bq, bk, bv], dim=0)
|
| 77 |
+
|
| 78 |
+
def key_mapping_attn(key):
|
| 79 |
+
return re.sub(
|
| 80 |
+
r"^transformer.layers.(\d+).self_attn.out_proj.",
|
| 81 |
+
r"transformer.layers.\1.mixer.out_proj.",
|
| 82 |
+
key,
|
| 83 |
+
)
|
| 84 |
+
|
| 85 |
+
state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())
|
| 86 |
+
|
| 87 |
+
return state_dict
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
def opt_config_to_gpt2_config(opt_config: OPTConfig) -> GPT2Config:
|
| 91 |
+
assert opt_config.layerdrop == 0.0
|
| 92 |
+
assert opt_config.layer_norm_elementwise_affine
|
| 93 |
+
word_embed_proj_dim = (
|
| 94 |
+
None
|
| 95 |
+
if opt_config.word_embed_proj_dim == opt_config.hidden_size
|
| 96 |
+
else opt_config.word_embed_proj_dim
|
| 97 |
+
)
|
| 98 |
+
return GPT2Config(
|
| 99 |
+
vocab_size=opt_config.vocab_size,
|
| 100 |
+
n_positions=opt_config.max_position_embeddings,
|
| 101 |
+
n_embd=opt_config.hidden_size,
|
| 102 |
+
n_layer=opt_config.num_hidden_layers,
|
| 103 |
+
n_head=opt_config.num_attention_heads,
|
| 104 |
+
n_inner=opt_config.ffn_dim,
|
| 105 |
+
activation_function=opt_config.activation_function,
|
| 106 |
+
resid_pdrop=opt_config.dropout,
|
| 107 |
+
# HF's implementation of OPT doesn't seem to have embedding dropout
|
| 108 |
+
embd_pdrop=opt_config.dropout,
|
| 109 |
+
attn_pdrop=opt_config.attention_dropout,
|
| 110 |
+
initializer_range=opt_config.init_std,
|
| 111 |
+
bos_token_id=opt_config.bos_token_id,
|
| 112 |
+
eos_token_id=opt_config.eos_token_id,
|
| 113 |
+
# These are new arguments not in the original GPT2Config
|
| 114 |
+
prenorm=opt_config.do_layer_norm_before,
|
| 115 |
+
word_embed_proj_dim=word_embed_proj_dim,
|
| 116 |
+
)
|
.venv/lib/python3.11/site-packages/xformers/_flash_attn/models/vit.py
ADDED
|
@@ -0,0 +1,373 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2022, Tri Dao.
|
| 2 |
+
# Inspired by / adapted from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
|
| 3 |
+
import math
|
| 4 |
+
import re
|
| 5 |
+
from collections import OrderedDict
|
| 6 |
+
from copy import deepcopy
|
| 7 |
+
from functools import partial
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
import torch.nn as nn
|
| 11 |
+
import torch.nn.functional as F
|
| 12 |
+
from einops import rearrange
|
| 13 |
+
from timm.models.helpers import named_apply
|
| 14 |
+
from torch.nn.init import trunc_normal_
|
| 15 |
+
from torchvision.ops import StochasticDepth
|
| 16 |
+
|
| 17 |
+
from flash_attn.layers.patch_embed import PatchEmbed
|
| 18 |
+
from flash_attn.modules.block import Block
|
| 19 |
+
from flash_attn.modules.mha import MHA
|
| 20 |
+
from flash_attn.modules.mlp import FusedMLP, Mlp
|
| 21 |
+
|
| 22 |
+
try:
|
| 23 |
+
from flash_attn.ops.triton.layer_norm import layer_norm_fn
|
| 24 |
+
except ImportError:
|
| 25 |
+
layer_norm_fn = None
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def create_mixer_cls(
|
| 29 |
+
num_heads, qkv_bias, attn_drop, use_flash_attn, fused_bias_fc, cross_attn=False
|
| 30 |
+
):
|
| 31 |
+
mixer_cls = partial(
|
| 32 |
+
MHA,
|
| 33 |
+
num_heads=num_heads,
|
| 34 |
+
cross_attn=cross_attn,
|
| 35 |
+
qkv_proj_bias=qkv_bias,
|
| 36 |
+
dropout=attn_drop,
|
| 37 |
+
fused_bias_fc=fused_bias_fc,
|
| 38 |
+
use_flash_attn=use_flash_attn,
|
| 39 |
+
)
|
| 40 |
+
return mixer_cls
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def create_mlp_cls(embed_dim, mlp_ratio, act_layer, fused_mlp):
|
| 44 |
+
inner_dim = int(embed_dim * mlp_ratio)
|
| 45 |
+
if not fused_mlp:
|
| 46 |
+
mlp_cls = partial(Mlp, hidden_features=inner_dim, activation=act_layer())
|
| 47 |
+
else:
|
| 48 |
+
mlp_cls = partial(FusedMLP, hidden_features=inner_dim)
|
| 49 |
+
return mlp_cls
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def create_block(
|
| 53 |
+
embed_dim,
|
| 54 |
+
num_heads,
|
| 55 |
+
mlp_ratio,
|
| 56 |
+
qkv_bias,
|
| 57 |
+
drop_rate,
|
| 58 |
+
attn_drop_rate,
|
| 59 |
+
drop_path1,
|
| 60 |
+
drop_path2,
|
| 61 |
+
norm_layer,
|
| 62 |
+
act_layer,
|
| 63 |
+
use_flash_attn,
|
| 64 |
+
fused_bias_fc,
|
| 65 |
+
fused_mlp,
|
| 66 |
+
fused_dropout_add_ln,
|
| 67 |
+
layer_idx=None,
|
| 68 |
+
n_layer=None,
|
| 69 |
+
last_layer_subset=False,
|
| 70 |
+
):
|
| 71 |
+
mixer_cls = create_mixer_cls(
|
| 72 |
+
num_heads,
|
| 73 |
+
qkv_bias,
|
| 74 |
+
attn_drop_rate,
|
| 75 |
+
use_flash_attn,
|
| 76 |
+
fused_bias_fc,
|
| 77 |
+
cross_attn=(last_layer_subset and layer_idx == n_layer - 1),
|
| 78 |
+
)
|
| 79 |
+
mlp_cls = create_mlp_cls(embed_dim, mlp_ratio, act_layer, fused_mlp)
|
| 80 |
+
# TD [2022-10-15]: Force residual in fp32 in case of DeepSpeed
|
| 81 |
+
block = Block(
|
| 82 |
+
embed_dim,
|
| 83 |
+
mixer_cls,
|
| 84 |
+
mlp_cls,
|
| 85 |
+
norm_cls=norm_layer,
|
| 86 |
+
prenorm=True,
|
| 87 |
+
resid_dropout1=drop_rate,
|
| 88 |
+
resid_dropout2=drop_rate,
|
| 89 |
+
drop_path1=drop_path1,
|
| 90 |
+
drop_path2=drop_path2,
|
| 91 |
+
fused_dropout_add_ln=fused_dropout_add_ln,
|
| 92 |
+
residual_in_fp32=True,
|
| 93 |
+
)
|
| 94 |
+
return block
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
class VisionTransformer(nn.Module):
|
| 98 |
+
"""Vision Transformer
|
| 99 |
+
A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale`
|
| 100 |
+
- https://arxiv.org/abs/2010.11929
|
| 101 |
+
"""
|
| 102 |
+
|
| 103 |
+
def __init__(
|
| 104 |
+
self,
|
| 105 |
+
img_size=224,
|
| 106 |
+
patch_size=16,
|
| 107 |
+
in_chans=3,
|
| 108 |
+
num_classes=1000,
|
| 109 |
+
global_pool="token",
|
| 110 |
+
embed_dim=768,
|
| 111 |
+
depth=12,
|
| 112 |
+
num_heads=12,
|
| 113 |
+
mlp_ratio=4.0,
|
| 114 |
+
qkv_bias=True,
|
| 115 |
+
init_values=None,
|
| 116 |
+
class_token=True,
|
| 117 |
+
no_embed_class=False,
|
| 118 |
+
pre_norm=False,
|
| 119 |
+
fc_norm=None,
|
| 120 |
+
drop_rate=0.0,
|
| 121 |
+
attn_drop_rate=0.0,
|
| 122 |
+
drop_path_rate=0.0,
|
| 123 |
+
weight_init="",
|
| 124 |
+
embed_layer=PatchEmbed,
|
| 125 |
+
norm_layer=None,
|
| 126 |
+
act_layer=None,
|
| 127 |
+
use_flash_attn=False,
|
| 128 |
+
fused_bias_fc=False,
|
| 129 |
+
fused_mlp=False,
|
| 130 |
+
fused_dropout_add_ln=False,
|
| 131 |
+
):
|
| 132 |
+
"""
|
| 133 |
+
Args:
|
| 134 |
+
img_size (int, tuple): input image size
|
| 135 |
+
patch_size (int, tuple): patch size
|
| 136 |
+
in_chans (int): number of input channels
|
| 137 |
+
num_classes (int): number of classes for classification head
|
| 138 |
+
global_pool (str): type of global pooling for final sequence (default: 'token')
|
| 139 |
+
embed_dim (int): embedding dimension
|
| 140 |
+
depth (int): depth of transformer
|
| 141 |
+
num_heads (int): number of attention heads
|
| 142 |
+
mlp_ratio (int): ratio of mlp hidden dim to embedding dim
|
| 143 |
+
qkv_bias (bool): enable bias for qkv if True
|
| 144 |
+
init_values: (float): layer-scale init values
|
| 145 |
+
class_token (bool): use class token
|
| 146 |
+
fc_norm (Optional[bool]): pre-fc norm after pool, set if global_pool == 'avg' if None (default: None)
|
| 147 |
+
drop_rate (float): dropout rate
|
| 148 |
+
attn_drop_rate (float): attention dropout rate
|
| 149 |
+
drop_path_rate (float): stochastic depth rate
|
| 150 |
+
weight_init (str): weight init scheme
|
| 151 |
+
embed_layer (nn.Module): patch embedding layer
|
| 152 |
+
norm_layer: (nn.Module): normalization layer
|
| 153 |
+
act_layer: (nn.Module): MLP activation layer
|
| 154 |
+
"""
|
| 155 |
+
super().__init__()
|
| 156 |
+
assert global_pool == "token", "Only support pooling with CLS token"
|
| 157 |
+
assert class_token
|
| 158 |
+
assert init_values is None, "LayerScale is not supported yet"
|
| 159 |
+
assert weight_init == ""
|
| 160 |
+
assert fc_norm is None
|
| 161 |
+
# pre_norm seems redundant, as there's a LayerNorm right at the start of each block, idk
|
| 162 |
+
assert not pre_norm
|
| 163 |
+
use_fc_norm = global_pool == "avg" if fc_norm is None else fc_norm
|
| 164 |
+
norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
|
| 165 |
+
act_layer = act_layer or nn.GELU
|
| 166 |
+
|
| 167 |
+
self.num_classes = num_classes
|
| 168 |
+
self.global_pool = global_pool
|
| 169 |
+
self.num_features = (
|
| 170 |
+
self.embed_dim
|
| 171 |
+
) = embed_dim # num_features for consistency with other models
|
| 172 |
+
self.num_prefix_tokens = 1 if class_token else 0
|
| 173 |
+
self.no_embed_class = no_embed_class
|
| 174 |
+
|
| 175 |
+
patch_embed_extra_kwargs = (
|
| 176 |
+
{"fused_bias_fc": fused_bias_fc} if embed_layer is PatchEmbed else {}
|
| 177 |
+
)
|
| 178 |
+
self.patch_embed = embed_layer(
|
| 179 |
+
img_size=img_size,
|
| 180 |
+
patch_size=patch_size,
|
| 181 |
+
in_chans=in_chans,
|
| 182 |
+
embed_dim=embed_dim,
|
| 183 |
+
bias=not pre_norm, # disable bias if pre-norm is used (e.g. CLIP)
|
| 184 |
+
**patch_embed_extra_kwargs,
|
| 185 |
+
)
|
| 186 |
+
num_patches = self.patch_embed.num_patches
|
| 187 |
+
|
| 188 |
+
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if class_token else None
|
| 189 |
+
embed_len = num_patches if no_embed_class else num_patches + self.num_prefix_tokens
|
| 190 |
+
self.pos_embed = nn.Parameter(torch.randn(1, embed_len, embed_dim) * 0.02)
|
| 191 |
+
|
| 192 |
+
dpr = [
|
| 193 |
+
x.item() for x in torch.linspace(0, drop_path_rate, depth)
|
| 194 |
+
] # stochastic depth decay rule
|
| 195 |
+
|
| 196 |
+
# We change the order of dropout, residual and layer norm:
|
| 197 |
+
# Instead of LN -> Attn / MLP -> Dropout -> Add, we do:
|
| 198 |
+
# Dropout -> Add -> LN -> Attn / MLP, returning both the residual branch (output of Add) and
|
| 199 |
+
# the main branch (output of MLP). The model definition is unchanged, but the mapping of the
|
| 200 |
+
# nn.Dropout probabilities are changed.
|
| 201 |
+
# This is for performance reason: we can fuse dropout + add + layer_norm.
|
| 202 |
+
self.blocks = nn.ModuleList(
|
| 203 |
+
[
|
| 204 |
+
create_block(
|
| 205 |
+
embed_dim,
|
| 206 |
+
num_heads,
|
| 207 |
+
mlp_ratio,
|
| 208 |
+
qkv_bias,
|
| 209 |
+
drop_rate,
|
| 210 |
+
attn_drop_rate,
|
| 211 |
+
drop_path1=dpr[i - 1] if i > 0 else 0.0,
|
| 212 |
+
drop_path2=dpr[i],
|
| 213 |
+
norm_layer=norm_layer,
|
| 214 |
+
act_layer=act_layer,
|
| 215 |
+
use_flash_attn=use_flash_attn,
|
| 216 |
+
fused_bias_fc=fused_bias_fc,
|
| 217 |
+
fused_mlp=fused_mlp,
|
| 218 |
+
fused_dropout_add_ln=fused_dropout_add_ln,
|
| 219 |
+
layer_idx=i,
|
| 220 |
+
n_layer=depth,
|
| 221 |
+
last_layer_subset=(global_pool == "token"),
|
| 222 |
+
)
|
| 223 |
+
for i in range(depth)
|
| 224 |
+
]
|
| 225 |
+
)
|
| 226 |
+
|
| 227 |
+
self.dropout = nn.Dropout(p=drop_rate)
|
| 228 |
+
self.drop_path = StochasticDepth(p=dpr[-1], mode="row")
|
| 229 |
+
self.norm = norm_layer(embed_dim)
|
| 230 |
+
|
| 231 |
+
self.fused_dropout_add_ln = fused_dropout_add_ln
|
| 232 |
+
if self.fused_dropout_add_ln and layer_norm_fn is None:
|
| 233 |
+
raise ImportError("Triton is not installed")
|
| 234 |
+
|
| 235 |
+
# Classifier Head
|
| 236 |
+
self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
|
| 237 |
+
|
| 238 |
+
self.init_weights(weight_init)
|
| 239 |
+
|
| 240 |
+
def init_weights(self, mode=""):
    """Initialize pos_embed / cls_token, then apply timm-style init to all submodules.

    Only the default (empty-string) init mode is supported.
    """
    assert mode == ""
    trunc_normal_(self.pos_embed, std=0.02)
    if self.cls_token is not None:
        nn.init.normal_(self.cls_token, std=1e-6)
    named_apply(init_weights_vit_timm, self)
|
| 246 |
+
|
| 247 |
+
def _init_weights(self, m):
    """Per-module init hook kept for downstream compatibility; delegates to timm-style init."""
    init_weights_vit_timm(m)
|
| 250 |
+
|
| 251 |
+
@torch.jit.ignore
def no_weight_decay(self):
    """Names of parameters that optimizers should exclude from weight decay."""
    return {"cls_token", "pos_embed"}
|
| 254 |
+
|
| 255 |
+
def _pos_embed(self, x):
    """Add positional embeddings and (optionally) prepend the cls token.

    Two layouts are supported:
      * no_embed_class=True (deit-3 / updated JAX big-vision): pos_embed covers
        patch tokens only, so add it first and concat the cls token afterwards.
      * no_embed_class=False (original timm / JAX / deit): pos_embed has a slot
        for the cls token, so concat first and then add.
    """
    if self.no_embed_class:
        x = x + self.pos_embed
        if self.cls_token is not None:
            cls_tokens = self.cls_token.expand(x.shape[0], -1, -1)
            x = torch.cat((cls_tokens, x), dim=1)
    else:
        if self.cls_token is not None:
            cls_tokens = self.cls_token.expand(x.shape[0], -1, -1)
            x = torch.cat((cls_tokens, x), dim=1)
        x = x + self.pos_embed
    return x
|
| 269 |
+
|
| 270 |
+
def forward_features(self, x, all_tokens=True):
    """Run patch embedding, the transformer blocks, and the final dropout+add+LN.

    If all_tokens is False and self.global_pool == 'token', the last block only
    computes features for the cls token (cross-attention from token 0 to the
    full sequence), which is cheaper when only the cls token is needed.
    """
    x = self.patch_embed(x)
    hidden_states = self._pos_embed(x)
    residual = None
    if self.global_pool != "token" or all_tokens:
        for blk in self.blocks:
            hidden_states, residual = blk(hidden_states, residual)
    else:
        for blk in self.blocks[:-1]:
            hidden_states, residual = blk(hidden_states, residual)
        # Last layer: the query is just the first (cls) token while key/value
        # is the whole sequence, so only the cls-token output is computed.
        hidden_states, residual = self.blocks[-1](
            hidden_states, residual, mixer_subset=slice(0, 1)
        )
    if not self.fused_dropout_add_ln:
        residual = self.drop_path(self.dropout(hidden_states)) + residual
        hidden_states = self.norm(residual.to(dtype=self.norm.weight.dtype))
    else:
        # Fused Triton path: dropout + residual add + layer norm in one kernel.
        if self.drop_path.p == 0 or not self.training:
            rowscale = None
        else:
            # Row-wise stochastic-depth mask, applied inside the fused kernel.
            rowscale = self.drop_path(
                torch.ones(
                    hidden_states.shape[:-1],
                    device=hidden_states.device,
                    dtype=hidden_states.dtype,
                )
            )
        # prenorm=False: the pre-norm residual is not needed after the last block.
        hidden_states = layer_norm_fn(
            hidden_states,
            self.norm.weight,
            self.norm.bias,
            residual=residual,
            eps=self.norm.eps,
            dropout_p=self.dropout.p if self.training else 0.0,
            rowscale=rowscale,
            prenorm=False,
        )
    return hidden_states
|
| 316 |
+
|
| 317 |
+
def forward_head(self, x, pre_logits: bool = False):
    """Pool the token features and apply the classifier head (skipped if pre_logits)."""
    if self.global_pool:
        if self.global_pool == "avg":
            # Average over patch tokens only, skipping the prefix (cls) tokens.
            x = x[:, self.num_prefix_tokens :].mean(dim=1)
        else:
            x = x[:, 0]
    return x if pre_logits else self.head(x)
|
| 321 |
+
|
| 322 |
+
def forward(self, x):
    """Full forward pass; only cls-token features are computed when pooling by token."""
    features = self.forward_features(x, all_tokens=False)
    return self.forward_head(features)
|
| 326 |
+
|
| 327 |
+
def load_state_dict(self, state_dict, strict=True):
    """Load a (possibly timm-format) checkpoint, remapping keys to this layout.

    Performs three conversions:
      * patch_embed Conv2d weight (O, C, H, W) -> Linear weight (O, C*H*W),
      * timm attention names (attn.qkv / attn.proj) -> mixer names (Wqkv / out_proj),
      * fused Wqkv of the last block -> separate Wq / Wkv when that block uses
        cross-attention (cls-token-only output).
    """
    patch_embed_weight = state_dict["patch_embed.proj.weight"]
    if patch_embed_weight.dim() == 4:
        # convert from Conv2d to Linear
        state_dict["patch_embed.proj.weight"] = rearrange(
            patch_embed_weight, "o c h w -> o (c h w)"
        )

    def key_mapping_attn(key):
        key = re.sub(r"^blocks.(\d+).attn.qkv.", r"blocks.\1.mixer.Wqkv.", key)
        key = re.sub(r"^blocks.(\d+).attn.proj.", r"blocks.\1.mixer.out_proj.", key)
        return key

    state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())
    last = len(self.blocks) - 1
    # Split the fused Wqkv of the final block into Wq / Wkv for cross-attention.
    if self.blocks[-1].mixer.cross_attn and f"blocks.{last}.mixer.Wqkv.weight" in state_dict:
        Wqkv = state_dict.pop(f"blocks.{last}.mixer.Wqkv.weight")
        bqkv = state_dict.pop(f"blocks.{last}.mixer.Wqkv.bias")
        state_dict[f"blocks.{last}.mixer.Wq.weight"] = Wqkv[: self.embed_dim]
        state_dict[f"blocks.{last}.mixer.Wkv.weight"] = Wqkv[self.embed_dim :]
        state_dict[f"blocks.{last}.mixer.Wq.bias"] = bqkv[: self.embed_dim]
        state_dict[f"blocks.{last}.mixer.Wkv.bias"] = bqkv[self.embed_dim :]
    return super().load_state_dict(state_dict, strict=strict)
|
| 354 |
+
|
| 355 |
+
|
| 356 |
+
def init_weights_vit_timm(module: nn.Module, name: str = ""):
    """ViT weight init matching the original timm implementation (for reproducibility).

    Linear layers get truncated-normal weights (std=0.02) and zero bias; any
    other module exposing an init_weights() method is asked to init itself.
    """
    if isinstance(module, nn.Linear):
        trunc_normal_(module.weight, std=0.02)
        if module.bias is not None:
            nn.init.zeros_(module.bias)
    elif hasattr(module, "init_weights"):
        module.init_weights()
|
| 364 |
+
|
| 365 |
+
|
| 366 |
+
def vit_base_patch16_224(pretrained=False, **kwargs):
    """ViT-Base (ViT-B/16) from the original paper (https://arxiv.org/abs/2010.11929).

    ImageNet-1k weights fine-tuned from in21k @ 224x224,
    source https://github.com/google-research/vision_transformer.
    Loading pretrained weights is not supported by this constructor.
    """
    assert not pretrained
    config = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
    return VisionTransformer(**config)
|
.venv/lib/python3.11/site-packages/xformers/_flash_attn/ops/__init__.py
ADDED
|
File without changes
|
.venv/lib/python3.11/site-packages/xformers/_flash_attn/ops/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (197 Bytes). View file
|
|
|
.venv/lib/python3.11/site-packages/xformers/_flash_attn/ops/__pycache__/activations.cpython-311.pyc
ADDED
|
Binary file (6.86 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/xformers/_flash_attn/ops/__pycache__/fused_dense.cpython-311.pyc
ADDED
|
Binary file (30.4 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/xformers/_flash_attn/ops/__pycache__/layer_norm.cpython-311.pyc
ADDED
|
Binary file (22.6 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/xformers/_flash_attn/ops/__pycache__/rms_norm.cpython-311.pyc
ADDED
|
Binary file (5.19 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/xformers/_flash_attn/ops/activations.py
ADDED
|
@@ -0,0 +1,135 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copied from https://github.com/mlcommons/training_results_v1.1/blob/main/NVIDIA/benchmarks/bert/implementations/pytorch/model/layers/activations.py
|
| 2 |
+
import math
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
import torch.nn as nn
|
| 6 |
+
import torch.nn.functional as F
|
| 7 |
+
|
| 8 |
+
# 1/sqrt(2*pi)-> 0.3989423
|
| 9 |
+
# 1/sqrt(2) -> 0.70710678
|
| 10 |
+
# sqrt(2/pi) -> 0.79788456
|
| 11 |
+
|
| 12 |
+
# this function is tanh approximation of gelu
|
| 13 |
+
# actual gelu is:
|
| 14 |
+
# x * 0.5 * (1.0 + torch.erf(x * 0.70710678))
|
| 15 |
+
# Fused bias-add + tanh approximation of GELU.
# Exact GELU is x * 0.5 * (1.0 + torch.erf(x * 0.70710678)).
@torch.jit.script
def bias_gelu(y, bias):
    x = bias + y
    inner = 0.79788456 * x * (1 + 0.044715 * x * x)  # sqrt(2/pi) * (x + 0.044715 * x^3)
    return (0.5 * x * (1.0 + torch.tanh(inner))).to(dtype=y.dtype)
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
# gradient of tanh approximation of gelu
|
| 22 |
+
# gradient of actual gelu is:
|
| 23 |
+
# 0.5 * (1. + torch.erf(x * 0.70710678)) + 0.3989423 * x * torch.exp(-0.5 * x * x)
|
| 24 |
+
# Gradient of the tanh GELU approximation above.
# Gradient of exact GELU is 0.5 * (1 + erf(x * 0.70710678)) + 0.3989423 * x * exp(-0.5 * x * x).
@torch.jit.script
def bias_gelu_back(g, y, bias):
    """Assume that y has shape (B, D) and bias has shape (D)"""
    x = bias + y
    t = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))
    # 0.1070322243 == sqrt(2/pi) * 3 * 0.044715
    dgelu = 0.5 * x * ((1 - t * t) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (1 + t)
    grad_y = dgelu * g
    # Bias gradient is the input gradient summed over the batch dimension.
    return grad_y.to(dtype=y.dtype), grad_y.sum(dim=(0), dtype=bias.dtype)
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
class GeLUFunction(torch.autograd.Function):
    """Fused bias-add + tanh-approx GELU with a hand-written backward."""

    @staticmethod
    # bias is an optional argument
    def forward(ctx, input, bias):
        ctx.save_for_backward(input, bias)
        return bias_gelu(input, bias)

    @staticmethod
    def backward(ctx, grad_output):
        input, bias = ctx.saved_tensors
        # bias_gelu_back returns (grad wrt input, grad wrt bias summed over the
        # batch dim). The previous `return tmp, tmp` handed the whole 2-tuple
        # back in both gradient slots, which autograd cannot consume.
        grad_input, grad_bias = bias_gelu_back(grad_output, input, bias)
        return grad_input, grad_bias


bias_gelu_impl = GeLUFunction.apply
|
| 52 |
+
|
| 53 |
+
# this function is tanh approximation of gelu
|
| 54 |
+
# actual gelu is:
|
| 55 |
+
# x * 0.5 * (1.0 + torch.erf(x * 0.70710678))
|
| 56 |
+
# Tanh approximation of GELU (exact: x * 0.5 * (1.0 + torch.erf(x * 0.70710678))).
@torch.jit.script
def gelu_fwd(x):
    inner = 0.79788456 * x * (1 + 0.044715 * x * x)  # sqrt(2/pi) * (x + 0.044715 * x^3)
    return (0.5 * x * (1.0 + torch.tanh(inner))).to(dtype=x.dtype)
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
# gradient of tanh approximation of gelu
|
| 62 |
+
# gradient of actual gelu is:
|
| 63 |
+
# 0.5 * (1. + torch.erf(x * 0.70710678)) + 0.3989423 * x * torch.exp(-0.5 * x * x)
|
| 64 |
+
# Gradient of the tanh GELU approximation above.
# Gradient of exact GELU is 0.5 * (1 + erf(x * 0.70710678)) + 0.3989423 * x * exp(-0.5 * x * x).
@torch.jit.script
def gelu_bwd(g, x):
    t = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))
    # 0.1070322243 == sqrt(2/pi) * 3 * 0.044715
    dgelu = 0.5 * x * ((1 - t * t) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (1 + t)
    return (dgelu * g).to(dtype=x.dtype)
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
class FastGeLUFunction(torch.autograd.Function):
    """Tanh-approx GELU (no bias) with a hand-written backward."""

    @staticmethod
    def forward(ctx, input):
        ctx.save_for_backward(input)
        return gelu_fwd(input)

    @staticmethod
    def backward(ctx, grad_output):
        (input,) = ctx.saved_tensors
        return gelu_bwd(grad_output, input)


fast_gelu_impl = FastGeLUFunction.apply
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
@torch.jit.script
def relu_bwd(g, x):
    """Gradient of ReLU: pass g through where x >= 0, zero elsewhere."""
    return torch.where(x >= 0, g, torch.zeros_like(g)).to(dtype=x.dtype)
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
@torch.jit.script
def sqrelu_fwd(x):
    """Squared ReLU: relu(x) ** 2."""
    clamped = F.relu(x)
    return (clamped * clamped).to(dtype=x.dtype)
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
@torch.jit.script
def sqrelu_bwd(g, x):
    """Gradient of squared ReLU: 2 * relu(x) * g."""
    return (g * (2.0 * F.relu(x))).to(dtype=x.dtype)
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
# CUDA jiterator kernels for SwiGLU: out = silu(x) * y = x * sigmoid(x) * y.
swiglu_fwd_codestring = """
template <typename T> T swiglu_fwd(T x, T y) {
    return float(x) * float(y) / (1.0f + ::exp(-float(x)));
}
"""
swiglu_bwd_codestring = """
template <typename T> T swiglu_bwd(T x, T y, T g, T& dx, T& dy) {
    float x_sigmoid = 1.0f / (1.0f + ::exp(-float(x)));
    dx = x_sigmoid * (1 + float(x) * (1.0f - x_sigmoid)) * float(g) * float(y);
    dy = float(x) * x_sigmoid * float(g);
}
"""
swiglu_fwd = torch.cuda.jiterator._create_jit_fn(swiglu_fwd_codestring)
swiglu_bwd = torch.cuda.jiterator._create_multi_output_jit_fn(swiglu_bwd_codestring, num_outputs=2)


class SwiGLUFunction(torch.autograd.Function):
    """SwiGLU (SiLU-gated linear unit) with fused CUDA forward/backward kernels."""

    @staticmethod
    def forward(ctx, x, y):
        ctx.save_for_backward(x, y)
        return swiglu_fwd(x, y)

    @staticmethod
    def backward(ctx, dout):
        x, y = ctx.saved_tensors
        # swiglu_bwd returns (dx, dy) as a pair, matching the two inputs.
        return swiglu_bwd(x, y, dout)


swiglu = SwiGLUFunction.apply
|
.venv/lib/python3.11/site-packages/xformers/_flash_attn/ops/fused_dense.py
ADDED
|
@@ -0,0 +1,688 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2023, Tri Dao.
|
| 2 |
+
# Inspired by https://github.com/NVIDIA/apex/blob/master/apex/fused_dense/fused_dense.py
|
| 3 |
+
# We make it work with pytorch amp and with bfloat16.
|
| 4 |
+
# The TensorParallel linear modules are inspired by https://github.com/NVIDIA/apex/blob/master/apex/transformer/tensor_parallel/layers.py
|
| 5 |
+
from functools import partial
|
| 6 |
+
from typing import Optional
|
| 7 |
+
|
| 8 |
+
# import fused_dense_cuda # from apex
|
| 9 |
+
import fused_dense_lib as fused_dense_cuda
|
| 10 |
+
import torch
|
| 11 |
+
import torch.nn as nn
|
| 12 |
+
import torch.nn.functional as F
|
| 13 |
+
from torch import Tensor
|
| 14 |
+
from torch.cuda.amp import custom_bwd, custom_fwd
|
| 15 |
+
from torch.distributed import ProcessGroup
|
| 16 |
+
|
| 17 |
+
from flash_attn.ops.activations import gelu_bwd, relu_bwd, sqrelu_bwd, sqrelu_fwd
|
| 18 |
+
from flash_attn.utils.distributed import (
|
| 19 |
+
all_gather_raw,
|
| 20 |
+
all_reduce,
|
| 21 |
+
all_reduce_raw,
|
| 22 |
+
reduce_scatter,
|
| 23 |
+
reduce_scatter_raw,
|
| 24 |
+
)
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
class FusedDenseFunc(torch.autograd.Function):
    """Linear layer autograd Function with optional tensor / sequence parallelism.

    Forward computes out = x @ weight^T + bias, optionally all-gathering x first.
    Backward computes grad_input with a matmul and grad_weight / grad_bias with
    the fused CUDA wgrad kernel, overlapping communication with computation.
    """

    @staticmethod
    @custom_fwd
    def forward(
        ctx, x, weight, bias, return_residual=False, process_group=None, sequence_parallel=True
    ):
        """
        If process_group is not None and sequence_parallel=True, we're doing Tensor Parallel
        with sequence parallelism: we do an all_gather_raw of x before doing the matmul.
        """
        ctx.compute_weight_gradient = weight.requires_grad
        ctx.return_residual = return_residual
        ctx.process_group = process_group
        ctx.sequence_parallel = sequence_parallel

        if torch.is_autocast_enabled():
            x = x.to(dtype=torch.get_autocast_gpu_dtype())
        x = x.contiguous()
        if process_group is not None and sequence_parallel:
            # Kick off the all_gather early, before the weight dtype conversion.
            total_x, handle_x = all_gather_raw(x, process_group, async_op=True)
        else:
            total_x = x

        if torch.is_autocast_enabled():
            weight = weight.to(dtype=torch.get_autocast_gpu_dtype())
            bias = bias.to(dtype=torch.get_autocast_gpu_dtype()) if bias is not None else None
        weight = weight.contiguous()
        if process_group is not None and sequence_parallel:
            handle_x.wait()
        batch_shape, n = total_x.shape[:-1], total_x.shape[-1]
        batch_dim = batch_shape.numel()
        # cuBLAS dimension limit, see
        # https://github.com/pytorch/pytorch/blob/5b51849b48a7dbccd297286cc0110def4706f9e7/aten/src/ATen/native/cuda/Blas.cpp#L174
        if min(batch_dim, n, *weight.shape) > 65535 * 32:
            raise RuntimeError("fused_dense only supports matrix dims <= 2M")
        output = F.linear(total_x, weight, bias)
        if ctx.compute_weight_gradient:
            ctx.save_for_backward(x, weight)
        else:
            ctx.save_for_backward(weight)
        return output if not return_residual else (output, x)

    @staticmethod
    @custom_bwd
    def backward(ctx, grad_output, *args):
        grad_output = grad_output.contiguous()
        if ctx.return_residual:
            (grad_input,) = args
            grad_input = grad_input.contiguous()
        process_group = ctx.process_group
        sequence_parallel = ctx.sequence_parallel
        if ctx.compute_weight_gradient:
            x, weight = ctx.saved_tensors
            if process_group is not None and sequence_parallel:
                total_x, handle_x = all_gather_raw(x, process_group, async_op=True)
            else:
                total_x = x
        else:
            (weight,) = ctx.saved_tensors
            total_x = None
        batch_shape = grad_output.shape[:-1]
        batch_dim = batch_shape.numel()
        grad_output = grad_output.reshape(batch_dim, grad_output.shape[-1])
        if ctx.needs_input_grad[0]:
            if not ctx.return_residual:
                grad_input = F.linear(grad_output, weight.t())
            else:
                # Accumulate into the residual gradient in a single addmm.
                grad_input = torch.addmm(
                    grad_input.reshape(batch_dim, grad_input.shape[-1]), grad_output, weight
                )
            grad_input = grad_input.reshape(*batch_shape, grad_input.shape[-1])
            if process_group is not None:
                reduce_fn = reduce_scatter_raw if sequence_parallel else all_reduce_raw
                grad_input, handle_grad_input = reduce_fn(grad_input, process_group, async_op=True)
        else:
            grad_input = None
        if ctx.needs_input_grad[1]:
            assert ctx.compute_weight_gradient
            if process_group is not None and sequence_parallel:
                handle_x.wait()
            grad_weight, grad_bias = fused_dense_cuda.linear_bias_wgrad(
                total_x.reshape(batch_dim, total_x.shape[-1]), grad_output, ctx.needs_input_grad[2]
            )
        else:
            grad_weight = None
            grad_bias = grad_output if ctx.needs_input_grad[2] else None
        if process_group is not None and ctx.needs_input_grad[0]:
            handle_grad_input.wait()
        return grad_input, grad_weight, grad_bias, None, None, None
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
def fused_dense_func(
    x: Tensor,
    weight: Tensor,
    bias: Optional[Tensor] = None,
    return_residual: bool = False,
    process_group: Optional["ProcessGroup"] = None,
    sequence_parallel: bool = True,
):
    """Dispatch to FusedDenseFunc on CUDA with fp16/bf16 (or fp32 under autocast);
    otherwise fall back to a plain F.linear.

    The fallback cannot implement tensor parallelism, so process_group must be
    None on that path.
    """
    dtype_eligible = x.dtype in [torch.float16, torch.bfloat16] or (
        x.dtype == torch.float32 and torch.is_autocast_enabled()
    )
    cuda_eligible = x.is_cuda and weight.is_cuda and (bias is None or bias.is_cuda)
    if cuda_eligible and dtype_eligible:
        return FusedDenseFunc.apply(
            x, weight, bias, return_residual, process_group, sequence_parallel
        )
    assert process_group is None
    out = F.linear(x, weight, bias)
    return out if not return_residual else (out, x)
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
class FusedDense(nn.Linear):
    """nn.Linear variant routed through the fused dense kernel; can optionally
    return its input alongside the output (for residual fusion)."""

    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool = True,
        return_residual: bool = False,
        device=None,
        dtype=None,
    ) -> None:
        super().__init__(in_features, out_features, bias=bias, device=device, dtype=dtype)
        self.return_residual = return_residual

    def forward(self, x, process_group=None):
        """
        If process_group is not None, we're doing Tensor Parallel with sequence parallelism:
        we do an all_gather of x before doing the matmul.
        """
        return fused_dense_func(
            x,
            self.weight,
            self.bias,
            return_residual=self.return_residual,
            process_group=process_group,
        )
|
| 164 |
+
|
| 165 |
+
|
| 166 |
+
class ColumnParallelLinear(nn.Linear):
    """Linear layer whose output features are sharded across the ranks of a
    process group (Megatron-style column parallelism)."""

    def __init__(
        self,
        in_features: int,
        out_features: int,
        process_group: ProcessGroup,
        bias: bool = True,
        sequence_parallel=True,
        multiple_of=1,
        device=None,
        dtype=None,
    ) -> None:
        world_size = torch.distributed.get_world_size(process_group)
        if out_features % multiple_of:
            raise ValueError(f"out_features ({out_features}) must be a multiple of {multiple_of}")
        multiple = out_features // multiple_of
        # Split @multiple units across world_size ranks; the first
        # (multiple % world_size) ranks get one extra unit, supporting uneven splits.
        div, mod = divmod(multiple, world_size)
        local_multiple = div + int(torch.distributed.get_rank(process_group) < mod)
        super().__init__(
            in_features, local_multiple * multiple_of, bias=bias, device=device, dtype=dtype
        )
        self.process_group = process_group
        self.sequence_parallel = sequence_parallel

    def forward(self, x):
        # With sequence parallelism the input is all-gathered before the matmul;
        # otherwise it is assumed to be already gathered.
        return fused_dense_func(
            x,
            self.weight,
            self.bias,
            process_group=self.process_group,
            sequence_parallel=self.sequence_parallel,
        )
|
| 204 |
+
|
| 205 |
+
|
| 206 |
+
class RowParallelLinear(nn.Linear):
    """Linear layer whose input features are sharded across the ranks of a
    process group (Megatron-style row parallelism); only rank 0 holds a bias."""

    def __init__(
        self,
        in_features: int,
        out_features: int,
        process_group: ProcessGroup,
        bias: bool = True,
        sequence_parallel=True,
        multiple_of=1,
        device=None,
        dtype=None,
    ) -> None:
        world_size = torch.distributed.get_world_size(process_group)
        rank = torch.distributed.get_rank(process_group)
        if in_features % multiple_of:
            raise ValueError(f"in_features ({in_features}) must be a multiple of {multiple_of}")
        multiple = in_features // multiple_of
        # Split @multiple units across world_size ranks; the first
        # (multiple % world_size) ranks get one extra unit, supporting uneven splits.
        div, mod = divmod(multiple, world_size)
        local_multiple = div + int(rank < mod)
        # Only rank 0 keeps a bias so it is added exactly once after the reduction.
        super().__init__(
            local_multiple * multiple_of,
            out_features,
            bias=bias and rank == 0,
            device=device,
            dtype=dtype,
        )
        self.process_group = process_group
        self.sequence_parallel = sequence_parallel

    def forward(self, x):
        """Matmul on the local shard, then reduce_scatter (sequence parallel) or
        all_reduce of the partial result across the process group."""
        out = fused_dense_func(x, self.weight, self.bias)
        reduce_fn = reduce_scatter if self.sequence_parallel else all_reduce
        return reduce_fn(out, self.process_group)
|
| 247 |
+
|
| 248 |
+
|
| 249 |
+
class FusedMLPFunc(torch.autograd.Function):
|
| 250 |
+
@staticmethod
|
| 251 |
+
@custom_fwd
|
| 252 |
+
def forward(
|
| 253 |
+
ctx,
|
| 254 |
+
x,
|
| 255 |
+
weight1,
|
| 256 |
+
bias1,
|
| 257 |
+
weight2,
|
| 258 |
+
bias2,
|
| 259 |
+
activation="gelu_approx",
|
| 260 |
+
save_pre_act=True,
|
| 261 |
+
return_residual=False,
|
| 262 |
+
checkpoint_lvl=0,
|
| 263 |
+
heuristic=0,
|
| 264 |
+
process_group=None,
|
| 265 |
+
sequence_parallel=True,
|
| 266 |
+
):
|
| 267 |
+
"""
|
| 268 |
+
If process_group is not None and sequence_parallel=True, we're doing Tensor Parallel
|
| 269 |
+
with sequence parallelism: we do an all_gather of x before doing the matmul.
|
| 270 |
+
If sequence_parallel=False, then the input is already gathered.
|
| 271 |
+
|
| 272 |
+
checkpoint_lvl:
|
| 273 |
+
0: no recomputation in the bwd
|
| 274 |
+
1: recompute gelu_out / relu_out in the bwd
|
| 275 |
+
2: recompute pre_act and gelu_out / relu_out in the bwd
|
| 276 |
+
"""
|
| 277 |
+
assert -1 <= heuristic <= 4
|
| 278 |
+
assert activation in ["gelu_approx", "relu", "sqrelu"]
|
| 279 |
+
if activation == "sqrelu":
|
| 280 |
+
assert heuristic == -1
|
| 281 |
+
if not save_pre_act:
|
| 282 |
+
checkpoint_lvl = 2
|
| 283 |
+
assert checkpoint_lvl in [0, 1, 2]
|
| 284 |
+
ctx.return_residual = return_residual
|
| 285 |
+
ctx.process_group = process_group
|
| 286 |
+
ctx.sequence_parallel = sequence_parallel
|
| 287 |
+
ctx.checkpoint_lvl = checkpoint_lvl
|
| 288 |
+
ctx.activation = activation
|
| 289 |
+
ctx.heuristic = heuristic
|
| 290 |
+
|
| 291 |
+
if torch.is_autocast_enabled():
|
| 292 |
+
x = x.to(dtype=torch.get_autocast_gpu_dtype())
|
| 293 |
+
x = x.contiguous()
|
| 294 |
+
if process_group is not None and sequence_parallel:
|
| 295 |
+
# We want to kick off the all_gather early, before weight dtype conversion
|
| 296 |
+
total_x, handle_x = all_gather_raw(x, process_group, async_op=True)
|
| 297 |
+
else:
|
| 298 |
+
total_x = x
|
| 299 |
+
|
| 300 |
+
if torch.is_autocast_enabled():
|
| 301 |
+
dtype = torch.get_autocast_gpu_dtype()
|
| 302 |
+
weight1, weight2 = [a.to(dtype=dtype) for a in [weight1, weight2]]
|
| 303 |
+
bias1 = bias1.to(dtype=dtype) if bias1 is not None else None
|
| 304 |
+
bias2 = bias2.to(dtype=dtype) if bias2 is not None else None
|
| 305 |
+
weight1 = weight1.contiguous()
|
| 306 |
+
bias1 = bias1.contiguous() if bias1 is not None else None
|
| 307 |
+
weight2 = weight2.contiguous()
|
| 308 |
+
bias2 = bias2.contiguous() if bias2 is not None else None
|
| 309 |
+
if process_group is not None and sequence_parallel:
|
| 310 |
+
handle_x.wait()
|
| 311 |
+
batch_shape, n = total_x.shape[:-1], total_x.shape[-1]
|
| 312 |
+
batch_dim = batch_shape.numel()
|
| 313 |
+
# https://github.com/pytorch/pytorch/blob/5b51849b48a7dbccd297286cc0110def4706f9e7/aten/src/ATen/native/cuda/Blas.cpp#L174
|
| 314 |
+
if min(batch_dim, n, *weight1.shape, *weight2.shape) > 65535 * 32:
|
| 315 |
+
raise RuntimeError("fused_dense only supports matrix dims <= 2M")
|
| 316 |
+
if heuristic == -1:
|
| 317 |
+
pre_act = F.linear(total_x, weight1, bias1)
|
| 318 |
+
activation_fn = (
|
| 319 |
+
partial(F.gelu, approximate="tanh")
|
| 320 |
+
if activation == "gelu_approx"
|
| 321 |
+
else (sqrelu_fwd if activation == "sqrelu" else F.relu)
|
| 322 |
+
)
|
| 323 |
+
with torch.jit.fuser("fuser2"):
|
| 324 |
+
output1 = activation_fn(pre_act)
|
| 325 |
+
# This is before adding bias1
|
| 326 |
+
# pre_act = F.linear(total_x.reshape(batch_dim, n), weight1)
|
| 327 |
+
# with torch.jit.fuser('fuser2'):
|
| 328 |
+
# output1 = bias_gelu(pre_act, bias1)
|
| 329 |
+
else:
|
| 330 |
+
is_gelu = activation == "gelu_approx"
|
| 331 |
+
output1, *rest = fused_dense_cuda.linear_act_forward(
|
| 332 |
+
total_x.reshape(batch_dim, n), weight1, bias1, is_gelu, save_pre_act, heuristic
|
| 333 |
+
)
|
| 334 |
+
if save_pre_act:
|
| 335 |
+
pre_act = rest[0]
|
| 336 |
+
output2 = F.linear(output1, weight2, bias2)
|
| 337 |
+
if checkpoint_lvl == 0 or (checkpoint_lvl == 1 and activation == "relu"):
|
| 338 |
+
# For RELU the pre_act is very small (just a bit-mask) so we just save it
|
| 339 |
+
ctx.save_for_backward(x, weight1, weight2, pre_act, output1)
|
| 340 |
+
elif checkpoint_lvl == 1:
|
| 341 |
+
ctx.save_for_backward(x, weight1, weight2, pre_act)
|
| 342 |
+
elif checkpoint_lvl == 2:
|
| 343 |
+
ctx.save_for_backward(x, weight1, weight2, bias1)
|
| 344 |
+
output2 = output2.reshape(*batch_shape, output2.shape[-1])
|
| 345 |
+
return output2 if not return_residual else (output2, x)
|
| 346 |
+
|
| 347 |
+
    @staticmethod
    @custom_bwd
    def backward(ctx, grad_output, *args):
        """Backward of the fused MLP.

        Gradient slots correspond positionally to forward's arguments
        (x, weight1, bias1, weight2, bias2, then the non-tensor args, which
        all get None).  Depending on ctx.checkpoint_lvl, pre_act / output1 were
        either saved in forward or are recomputed here.  With sequence
        parallelism the all-gather of x is issued asynchronously so it
        overlaps with the wgrad2 / dgrad computation.
        """
        grad_output = grad_output.contiguous()
        checkpoint_lvl = ctx.checkpoint_lvl
        activation = ctx.activation
        activation_fn = (
            partial(F.gelu, approximate="tanh")
            if activation == "gelu_approx"
            else (sqrelu_fwd if activation == "sqrelu" else F.relu)
        )
        if ctx.return_residual:
            # forward returned (output, x); the extra grad in *args is the
            # residual gradient flowing back into x.
            (grad_input,) = args
            grad_input = grad_input.contiguous()
        process_group = ctx.process_group
        sequence_parallel = ctx.sequence_parallel
        # rest holds the checkpoint-dependent saved tensors (see forward).
        x, weight1, weight2, *rest = ctx.saved_tensors
        if process_group is None or not sequence_parallel:
            total_x = x
        batch_shape = grad_output.shape[:-1]
        batch_dim = batch_shape.numel()
        if checkpoint_lvl in [0, 1]:
            if process_group is not None and sequence_parallel:
                # Async: handle_x.wait() happens right before wgrad1 below,
                # overlapping the all-gather with the grad computations.
                total_x, handle_x = all_gather_raw(x, process_group, async_op=True)
            if checkpoint_lvl == 0 or (checkpoint_lvl == 1 and activation == "relu"):
                # For relu the saved pre_act is cheap (bit-mask-like), so lvl 1
                # also saved output1 in that case.
                pre_act, output1 = rest
            elif checkpoint_lvl == 1:
                (pre_act,) = rest
                with torch.jit.fuser("fuser2"):
                    output1 = activation_fn(pre_act)
        elif checkpoint_lvl == 2:
            # Heaviest recompute: redo fc1 + activation from x and bias1.
            (bias1,) = rest
            if process_group is not None and sequence_parallel:
                # Synchronous gather: total_x is needed immediately for recompute.
                total_x, _ = all_gather_raw(x, process_group)
            if ctx.heuristic == -1:
                # Unfused path: separate GEMM and activation kernels.
                pre_act = F.linear(total_x, weight1, bias1)
                with torch.jit.fuser("fuser2"):
                    output1 = activation_fn(pre_act)
            else:
                output1, pre_act = fused_dense_cuda.linear_act_forward(
                    total_x.reshape(batch_dim, total_x.shape[-1]),
                    weight1,
                    bias1,
                    activation == "gelu_approx",
                    True,
                    ctx.heuristic,
                )

        # Flatten leading batch dims for the 2D GEMM kernels.
        grad_output = grad_output.reshape(batch_dim, grad_output.shape[-1])
        output1 = output1.reshape(batch_dim, output1.shape[-1])
        pre_act = pre_act.reshape(batch_dim, pre_act.shape[-1])
        # needs_input_grad indices: 0=x, 1=weight1, 2=bias1, 3=weight2, 4=bias2.
        if ctx.needs_input_grad[3]:
            grad_weight2, grad_bias2 = fused_dense_cuda.linear_bias_wgrad(
                output1, grad_output, ctx.needs_input_grad[4]
            )
        else:
            grad_weight2 = None
            grad_bias2 = grad_output if ctx.needs_input_grad[4] else None
        if ctx.heuristic == -1:
            # grad_pre_act = matmul_dgelu(grad_output, weight2, pre_act)
            grad_output1 = F.linear(grad_output, weight2.t())
            activation_grad_fn = (
                gelu_bwd
                if activation == "gelu_approx"
                else (sqrelu_bwd if activation == "sqrelu" else relu_bwd)
            )
            with torch.jit.fuser("fuser2"):
                grad_pre_act = activation_grad_fn(grad_output1, pre_act)
        else:
            # The cublasLt epilogue has to compute both gelu/relu grad and bias grad, we can't
            # just compute gelu/relu grad
            grad_pre_act, grad_bias1 = fused_dense_cuda.bias_act_linear_dgrad_bgrad(
                weight2, grad_output, pre_act, activation == "gelu_approx", ctx.heuristic
            )
            if not ctx.needs_input_grad[2]:
                grad_bias1 = None
        if ctx.needs_input_grad[0]:
            if not ctx.return_residual:
                grad_input = F.linear(grad_pre_act, weight1.t())
            else:
                # Fuse the residual-add into the GEMM epilogue via addmm.
                grad_input = torch.addmm(
                    grad_input.reshape(batch_dim, grad_input.shape[-1]), grad_pre_act, weight1
                )
            grad_input = grad_input.reshape(*batch_shape, grad_input.shape[-1])
            if process_group is not None:
                # Async reduce of grad_input overlaps with wgrad1 below.
                reduce_fn = reduce_scatter_raw if sequence_parallel else all_reduce_raw
                grad_input, handle_grad_input = reduce_fn(grad_input, process_group, async_op=True)
        else:
            grad_input = None
        if ctx.heuristic == -1:
            if ctx.needs_input_grad[1]:
                if process_group is not None and sequence_parallel and checkpoint_lvl != 2:
                    # total_x from the async all-gather above must be ready now.
                    handle_x.wait()
                grad_weight1, grad_bias1 = fused_dense_cuda.linear_bias_wgrad(
                    total_x.reshape(batch_dim, total_x.shape[-1]),
                    grad_pre_act,
                    ctx.needs_input_grad[2],
                )
            else:
                grad_weight1 = None
                grad_bias1 = grad_pre_act if ctx.needs_input_grad[2] else None
        else:
            # grad_bias1 was already produced by bias_act_linear_dgrad_bgrad above.
            if ctx.needs_input_grad[1]:
                if process_group is not None and sequence_parallel and checkpoint_lvl != 2:
                    handle_x.wait()
                grad_weight1 = F.linear(
                    grad_pre_act.t(), total_x.reshape(batch_dim, total_x.shape[-1]).t()
                )
            else:
                grad_weight1 = None
        if process_group is not None and ctx.needs_input_grad[0]:
            # Make sure the async reduce of grad_input finished before returning.
            handle_grad_input.wait()
        return (
            grad_input,
            grad_weight1,
            grad_bias1,
            grad_weight2,
            grad_bias2,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
        )
|
| 473 |
+
|
| 474 |
+
|
| 475 |
+
def fused_mlp_func(
    x: Tensor,
    weight1: Tensor,
    weight2: Tensor,
    bias1: Optional[Tensor] = None,
    bias2: Optional[Tensor] = None,
    activation: str = "gelu_approx",
    save_pre_act: bool = True,
    return_residual: bool = False,
    checkpoint_lvl: int = 0,
    heuristic: int = 0,
    process_group: Optional[ProcessGroup] = None,
    sequence_parallel: bool = True,
):
    """Two-layer MLP ``fc2(activation(fc1(x)))`` using fused CUDA kernels when eligible.

    Dispatches to ``FusedMLPFunc`` when all tensors live on CUDA and the
    dtype / hidden-dimension constraints of the fused kernels hold; otherwise
    falls back to plain PyTorch ops.  The fallback does not support tensor
    parallelism, hence the ``process_group is None`` assert.

    Returns the output tensor, or ``(output, x)`` when ``return_residual``.
    """
    assert activation in ["gelu_approx", "relu", "sqrelu"]
    # fp16/bf16 always eligible; fp32 only under autocast (it will be cast down).
    dtype_eligible = x.dtype in [torch.float16, torch.bfloat16] or (
        x.dtype == torch.float32 and torch.is_autocast_enabled()
    )
    # If we save pre-activation, dimension must be divisible by 128 (relu) or 8 (gelu)
    dim_eligible = not save_pre_act or (x.shape[-1] % (128 if activation == "relu" else 8) == 0)
    if (
        x.is_cuda
        and weight1.is_cuda
        and weight2.is_cuda
        and (bias1 is None or bias1.is_cuda)
        and (bias2 is None or bias2.is_cuda)
        and dtype_eligible
        and dim_eligible
    ):
        return FusedMLPFunc.apply(
            x,
            weight1,
            bias1,
            weight2,
            bias2,
            activation,
            save_pre_act,
            return_residual,
            checkpoint_lvl,
            heuristic,
            process_group,
            sequence_parallel,
        )
    else:
        assert process_group is None
        pre_act = F.linear(x, weight1, bias1)
        # BUGFIX: previously "sqrelu" silently fell through to plain ReLU in this
        # fallback; use sqrelu_fwd so it matches the fused path's activation.
        activation_fn = (
            partial(F.gelu, approximate="tanh")
            if activation == "gelu_approx"
            else (sqrelu_fwd if activation == "sqrelu" else partial(F.relu, inplace=True))
        )
        output1 = activation_fn(pre_act)
        output2 = F.linear(output1, weight2, bias2)
        return output2 if not return_residual else (output2, x)
|
| 529 |
+
|
| 530 |
+
|
| 531 |
+
class FusedMLP(nn.Module):
    """Two-layer MLP backed by fused GEMM + activation CUDA kernels.

    If a process_group is passed to forward(), Tensor Parallel with sequence
    parallelism is used: x is all-gathered before fc1, and the output is
    reduce-scattered at the end.

    checkpoint_lvl (higher = slower backward, lower memory):
        0: save everything, no recomputation in the bwd
        1: recompute the activation output in the bwd
        2: recompute both pre-activation and activation output in the bwd
    heuristic:
        -1: run GEMM and activation as separate kernels (unfused)
        0..4: algo-section heuristic for the fused GEMM + activation kernel
        'auto': pick per device/CUDA version at forward time:
            CUDA >= 11.8 -> 0 for fp16 and bf16;
            CUDA <= 11.7 -> 1 for fp16, -1 for bf16;
            H100 (sm90)  -> -1 (the fused cuBlasLt kernel is slower there).
    return_residual: also return the input x with the output, so a post-norm
        caller can fuse the residual add into the linear backward.
    """

    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        bias1=True,
        bias2=True,
        activation="gelu_approx",
        return_residual=False,
        checkpoint_lvl=0,
        heuristic="auto",
        device=None,
        dtype=None,
    ):
        assert checkpoint_lvl in [0, 1, 2]
        assert activation in ["gelu_approx", "relu", "sqrelu"]
        super().__init__()
        factory_kwargs = {"device": device, "dtype": dtype}
        out_features = out_features or in_features
        hidden_features = hidden_features or 4 * in_features
        self.activation = activation
        self.return_residual = return_residual
        self.checkpoint_lvl = checkpoint_lvl
        # sqrelu has no fused-kernel support, so force the unfused path.
        self.heuristic = -1 if activation == "sqrelu" else heuristic
        self.fc1 = nn.Linear(in_features, hidden_features, bias=bias1, **factory_kwargs)
        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias2, **factory_kwargs)

    def forward(self, x, process_group=None):
        # Effective compute dtype (autocast may downcast fp32 inputs).
        dtype = torch.get_autocast_gpu_dtype() if torch.is_autocast_enabled() else x.dtype
        heuristic = self.heuristic
        if heuristic == "auto":
            if self.activation != "gelu_approx":
                heuristic = 0
            elif torch.cuda.get_device_capability("cuda") == (9, 0):
                # H100: the fused cuBlasLt implementation is slower than unfused.
                heuristic = -1
            else:
                cuda_version = tuple(int(v) for v in torch.version.cuda.split("."))
                if cuda_version >= (11, 8):
                    heuristic = 0
                else:
                    heuristic = 1 if dtype == torch.float16 else -1
        out = fused_mlp_func(
            x,
            self.fc1.weight,
            self.fc2.weight,
            self.fc1.bias,
            self.fc2.bias,
            activation=self.activation,
            save_pre_act=self.training,
            return_residual=self.return_residual,
            checkpoint_lvl=self.checkpoint_lvl,
            heuristic=heuristic,
            process_group=process_group,
        )
        if self.return_residual:
            out, x = out
        if process_group is not None:
            out = reduce_scatter(out, process_group)
        return out if not self.return_residual else (out, x)
|
| 611 |
+
|
| 612 |
+
|
| 613 |
+
class ParallelFusedMLP(nn.Module):
    """Tensor-parallel fused MLP (column-parallel fc1, row-parallel fc2).

    process_group is required.  Sequence parallelism is used when enabled:
    x is all-gathered before fc1 and the output is reduce-scattered (or
    all-reduced when sequence_parallel is False).

    checkpoint_lvl (higher = slower backward, lower memory):
        0: save everything, no recomputation in the bwd
        1: recompute the activation output in the bwd
        2: recompute both pre-activation and activation output in the bwd
    heuristic:
        -1: run GEMM and activation as separate kernels (unfused)
        0..4: algo-section heuristic for the fused GEMM + activation kernel
        'auto': pick per CUDA version at forward time:
            CUDA >= 11.8 -> 0 for fp16 and bf16;
            CUDA <= 11.7 -> 1 for fp16, -1 for bf16.
    """

    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        activation="gelu_approx",
        process_group: ProcessGroup = None,
        bias1=True,
        bias2=True,
        sequence_parallel=True,
        checkpoint_lvl=0,
        heuristic="auto",
        device=None,
        dtype=None,
    ):
        assert checkpoint_lvl in [0, 1, 2]
        assert activation in ["gelu_approx", "relu", "sqrelu"]
        assert process_group is not None
        super().__init__()
        factory_kwargs = {"device": device, "dtype": dtype}
        out_features = out_features or in_features
        hidden_features = hidden_features or 4 * in_features
        self.activation = activation
        self.process_group = process_group
        self.sequence_parallel = sequence_parallel
        self.checkpoint_lvl = checkpoint_lvl
        # sqrelu has no fused-kernel support, so force the unfused path.
        self.heuristic = -1 if activation == "sqrelu" else heuristic
        self.fc1 = ColumnParallelLinear(
            in_features, hidden_features, process_group, bias=bias1, **factory_kwargs
        )
        self.fc2 = RowParallelLinear(
            hidden_features, out_features, process_group, bias=bias2, **factory_kwargs
        )

    def forward(self, x):
        # Effective compute dtype (autocast may downcast fp32 inputs).
        dtype = torch.get_autocast_gpu_dtype() if torch.is_autocast_enabled() else x.dtype
        heuristic = self.heuristic
        if heuristic == "auto":
            if self.activation != "gelu_approx":
                heuristic = 0
            else:
                cuda_version = tuple(int(v) for v in torch.version.cuda.split("."))
                if cuda_version >= (11, 8):
                    heuristic = 0
                else:
                    heuristic = 1 if dtype == torch.float16 else -1
        out = fused_mlp_func(
            x,
            self.fc1.weight,
            self.fc2.weight,
            self.fc1.bias,
            self.fc2.bias,
            activation=self.activation,
            save_pre_act=self.training,
            checkpoint_lvl=self.checkpoint_lvl,
            heuristic=heuristic,
            process_group=self.process_group,
            sequence_parallel=self.sequence_parallel,
        )
        reduce_fn = reduce_scatter if self.sequence_parallel else all_reduce
        return reduce_fn(out, self.process_group)
|