koichi12 commited on
Commit
d33aea4
·
verified ·
1 Parent(s): 640f355

Add files using upload-large-folder tool

Browse files
This view is limited to 50 files because the commit contains too many changes. See the raw diff for the complete list.
Files changed (50) hide show
  1. .venv/lib/python3.11/site-packages/xformers/_flash_attn/__pycache__/__init__.cpython-311.pyc +0 -0
  2. .venv/lib/python3.11/site-packages/xformers/_flash_attn/__pycache__/bert_padding.cpython-311.pyc +0 -0
  3. .venv/lib/python3.11/site-packages/xformers/_flash_attn/__pycache__/flash_attn_interface.cpython-311.pyc +0 -0
  4. .venv/lib/python3.11/site-packages/xformers/_flash_attn/__pycache__/flash_attn_triton.cpython-311.pyc +0 -0
  5. .venv/lib/python3.11/site-packages/xformers/_flash_attn/__pycache__/flash_attn_triton_og.cpython-311.pyc +0 -0
  6. .venv/lib/python3.11/site-packages/xformers/_flash_attn/__pycache__/flash_blocksparse_attention.cpython-311.pyc +0 -0
  7. .venv/lib/python3.11/site-packages/xformers/_flash_attn/__pycache__/flash_blocksparse_attn_interface.cpython-311.pyc +0 -0
  8. .venv/lib/python3.11/site-packages/xformers/_flash_attn/__pycache__/fused_softmax.cpython-311.pyc +0 -0
  9. .venv/lib/python3.11/site-packages/xformers/_flash_attn/layers/__init__.py +0 -0
  10. .venv/lib/python3.11/site-packages/xformers/_flash_attn/layers/__pycache__/__init__.cpython-311.pyc +0 -0
  11. .venv/lib/python3.11/site-packages/xformers/_flash_attn/layers/__pycache__/patch_embed.cpython-311.pyc +0 -0
  12. .venv/lib/python3.11/site-packages/xformers/_flash_attn/layers/__pycache__/rotary.cpython-311.pyc +0 -0
  13. .venv/lib/python3.11/site-packages/xformers/_flash_attn/layers/patch_embed.py +67 -0
  14. .venv/lib/python3.11/site-packages/xformers/_flash_attn/layers/rotary.py +481 -0
  15. .venv/lib/python3.11/site-packages/xformers/_flash_attn/losses/__init__.py +0 -0
  16. .venv/lib/python3.11/site-packages/xformers/_flash_attn/losses/__pycache__/__init__.cpython-311.pyc +0 -0
  17. .venv/lib/python3.11/site-packages/xformers/_flash_attn/losses/__pycache__/cross_entropy.cpython-311.pyc +0 -0
  18. .venv/lib/python3.11/site-packages/xformers/_flash_attn/losses/cross_entropy.py +85 -0
  19. .venv/lib/python3.11/site-packages/xformers/_flash_attn/models/__init__.py +0 -0
  20. .venv/lib/python3.11/site-packages/xformers/_flash_attn/models/__pycache__/__init__.cpython-311.pyc +0 -0
  21. .venv/lib/python3.11/site-packages/xformers/_flash_attn/models/__pycache__/baichuan.cpython-311.pyc +0 -0
  22. .venv/lib/python3.11/site-packages/xformers/_flash_attn/models/__pycache__/bert.cpython-311.pyc +0 -0
  23. .venv/lib/python3.11/site-packages/xformers/_flash_attn/models/__pycache__/bigcode.cpython-311.pyc +0 -0
  24. .venv/lib/python3.11/site-packages/xformers/_flash_attn/models/__pycache__/btlm.cpython-311.pyc +0 -0
  25. .venv/lib/python3.11/site-packages/xformers/_flash_attn/models/__pycache__/falcon.cpython-311.pyc +0 -0
  26. .venv/lib/python3.11/site-packages/xformers/_flash_attn/models/__pycache__/gpt.cpython-311.pyc +0 -0
  27. .venv/lib/python3.11/site-packages/xformers/_flash_attn/models/__pycache__/gpt_neox.cpython-311.pyc +0 -0
  28. .venv/lib/python3.11/site-packages/xformers/_flash_attn/models/__pycache__/gptj.cpython-311.pyc +0 -0
  29. .venv/lib/python3.11/site-packages/xformers/_flash_attn/models/__pycache__/llama.cpython-311.pyc +0 -0
  30. .venv/lib/python3.11/site-packages/xformers/_flash_attn/models/__pycache__/opt.cpython-311.pyc +0 -0
  31. .venv/lib/python3.11/site-packages/xformers/_flash_attn/models/__pycache__/vit.cpython-311.pyc +0 -0
  32. .venv/lib/python3.11/site-packages/xformers/_flash_attn/models/baichuan.py +151 -0
  33. .venv/lib/python3.11/site-packages/xformers/_flash_attn/models/bert.py +764 -0
  34. .venv/lib/python3.11/site-packages/xformers/_flash_attn/models/bigcode.py +233 -0
  35. .venv/lib/python3.11/site-packages/xformers/_flash_attn/models/btlm.py +102 -0
  36. .venv/lib/python3.11/site-packages/xformers/_flash_attn/models/falcon.py +143 -0
  37. .venv/lib/python3.11/site-packages/xformers/_flash_attn/models/gpt.py +1080 -0
  38. .venv/lib/python3.11/site-packages/xformers/_flash_attn/models/gpt_neox.py +124 -0
  39. .venv/lib/python3.11/site-packages/xformers/_flash_attn/models/gptj.py +109 -0
  40. .venv/lib/python3.11/site-packages/xformers/_flash_attn/models/llama.py +422 -0
  41. .venv/lib/python3.11/site-packages/xformers/_flash_attn/models/opt.py +116 -0
  42. .venv/lib/python3.11/site-packages/xformers/_flash_attn/models/vit.py +373 -0
  43. .venv/lib/python3.11/site-packages/xformers/_flash_attn/ops/__init__.py +0 -0
  44. .venv/lib/python3.11/site-packages/xformers/_flash_attn/ops/__pycache__/__init__.cpython-311.pyc +0 -0
  45. .venv/lib/python3.11/site-packages/xformers/_flash_attn/ops/__pycache__/activations.cpython-311.pyc +0 -0
  46. .venv/lib/python3.11/site-packages/xformers/_flash_attn/ops/__pycache__/fused_dense.cpython-311.pyc +0 -0
  47. .venv/lib/python3.11/site-packages/xformers/_flash_attn/ops/__pycache__/layer_norm.cpython-311.pyc +0 -0
  48. .venv/lib/python3.11/site-packages/xformers/_flash_attn/ops/__pycache__/rms_norm.cpython-311.pyc +0 -0
  49. .venv/lib/python3.11/site-packages/xformers/_flash_attn/ops/activations.py +135 -0
  50. .venv/lib/python3.11/site-packages/xformers/_flash_attn/ops/fused_dense.py +688 -0
.venv/lib/python3.11/site-packages/xformers/_flash_attn/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (606 Bytes). View file
 
.venv/lib/python3.11/site-packages/xformers/_flash_attn/__pycache__/bert_padding.cpython-311.pyc ADDED
Binary file (11 kB). View file
 
.venv/lib/python3.11/site-packages/xformers/_flash_attn/__pycache__/flash_attn_interface.cpython-311.pyc ADDED
Binary file (46.3 kB). View file
 
.venv/lib/python3.11/site-packages/xformers/_flash_attn/__pycache__/flash_attn_triton.cpython-311.pyc ADDED
Binary file (44.5 kB). View file
 
.venv/lib/python3.11/site-packages/xformers/_flash_attn/__pycache__/flash_attn_triton_og.cpython-311.pyc ADDED
Binary file (16.2 kB). View file
 
.venv/lib/python3.11/site-packages/xformers/_flash_attn/__pycache__/flash_blocksparse_attention.cpython-311.pyc ADDED
Binary file (8.15 kB). View file
 
.venv/lib/python3.11/site-packages/xformers/_flash_attn/__pycache__/flash_blocksparse_attn_interface.cpython-311.pyc ADDED
Binary file (7.53 kB). View file
 
.venv/lib/python3.11/site-packages/xformers/_flash_attn/__pycache__/fused_softmax.cpython-311.pyc ADDED
Binary file (9.45 kB). View file
 
.venv/lib/python3.11/site-packages/xformers/_flash_attn/layers/__init__.py ADDED
File without changes
.venv/lib/python3.11/site-packages/xformers/_flash_attn/layers/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (200 Bytes). View file
 
.venv/lib/python3.11/site-packages/xformers/_flash_attn/layers/__pycache__/patch_embed.cpython-311.pyc ADDED
Binary file (3.34 kB). View file
 
.venv/lib/python3.11/site-packages/xformers/_flash_attn/layers/__pycache__/rotary.cpython-311.pyc ADDED
Binary file (20.1 kB). View file
 
.venv/lib/python3.11/site-packages/xformers/_flash_attn/layers/patch_embed.py ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # We use the same API as https://github.com/rwightman/pytorch-image-models/blob/v0.6.11/timm/models/layers/patch_embed.py
2
+ # But we use nn.Linear instead of Conv2d and it's about 8x faster.
3
+
4
+ from functools import partial
5
+
6
+ import torch.nn as nn
7
+ from einops import rearrange
8
+ from torch import _assert
9
+ from torch.nn.modules.utils import _pair
10
+
11
+ try:
12
+ from flash_attn.ops.fused_dense import FusedDense
13
+ except ImportError:
14
+ FusedDense = None
15
+
16
+
17
+ class PatchEmbed(nn.Module):
18
+ """2D Image to Patch Embedding"""
19
+
20
+ def __init__(
21
+ self,
22
+ img_size=224,
23
+ patch_size=16,
24
+ in_chans=3,
25
+ embed_dim=768,
26
+ norm_layer=None,
27
+ flatten=True,
28
+ bias=True,
29
+ fused_bias_fc=False,
30
+ ):
31
+ super().__init__()
32
+ img_size = _pair(img_size)
33
+ patch_size = _pair(patch_size)
34
+ self.img_size = img_size
35
+ self.patch_size = patch_size
36
+ self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
37
+ self.num_patches = self.grid_size[0] * self.grid_size[1]
38
+ self.flatten = flatten
39
+ if fused_bias_fc and FusedDense is None:
40
+ raise ImportError("fused_dense is not installed")
41
+
42
+ linear_cls = nn.Linear if not fused_bias_fc or not bias else FusedDense
43
+ self.proj = linear_cls(in_chans * patch_size[0] * patch_size[1], embed_dim, bias=bias)
44
+ self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
45
+
46
+ def forward(self, x):
47
+ _, _, H, W = x.shape
48
+ _assert(
49
+ H == self.img_size[0],
50
+ f"Input image height ({H}) doesn't match model ({self.img_size[0]}).",
51
+ )
52
+ _assert(
53
+ W == self.img_size[1],
54
+ f"Input image width ({W}) doesn't match model ({self.img_size[1]}).",
55
+ )
56
+ x = self.proj(
57
+ rearrange(
58
+ x,
59
+ "b c (h p1) (w p2) -> b h w (c p1 p2)",
60
+ p1=self.patch_size[0],
61
+ p2=self.patch_size[1],
62
+ )
63
+ )
64
+ if self.flatten:
65
+ x = rearrange(x, "b h w c -> b (h w) c")
66
+ x = self.norm(x)
67
+ return x
.venv/lib/python3.11/site-packages/xformers/_flash_attn/layers/rotary.py ADDED
@@ -0,0 +1,481 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2023, Tri Dao.
2
+
3
+ import math
4
+ from typing import Optional, Tuple, Union
5
+
6
+ import torch
7
+ from einops import rearrange, repeat
8
+ from flash_attn.ops.triton.rotary import apply_rotary
9
+
10
+
11
def rotate_half(x, interleaved=False):
    """Rotate the last dimension of ``x`` by 90 degrees in 2D sub-planes.

    With ``interleaved=False`` (GPT-NeoX style) the halves (x1, x2) map to
    (-x2, x1); with ``interleaved=True`` (GPT-J style) the same rotation is
    applied to (even, odd) channel pairs.
    """
    if interleaved:
        even, odd = x[..., ::2], x[..., 1::2]
        return rearrange(torch.stack((-odd, even), dim=-1), "... d two -> ... (d two)", two=2)
    first_half, second_half = x.chunk(2, dim=-1)
    return torch.cat((-second_half, first_half), dim=-1)
18
+
19
+
20
def apply_rotary_emb_torch(x, cos, sin, interleaved=False):
    """Pure-PyTorch rotary embedding (reference implementation).

    x: (batch_size, seqlen, nheads, headdim)
    cos, sin: (seqlen, rotary_dim / 2) or (batch_size, seqlen, rotary_dim / 2)

    Only the first ``rotary_dim`` channels of x are rotated; the rest pass
    through unchanged.
    """
    rotary_dim = cos.shape[-1] * 2
    assert rotary_dim <= x.shape[-1]
    # Duplicate each cos/sin entry across its channel pair and broadcast over heads.
    pattern = "... d -> ... 1 (d 2)" if interleaved else "... d -> ... 1 (2 d)"
    cos = repeat(cos, pattern)
    sin = repeat(sin, pattern)
    x_rot = x[..., :rotary_dim]
    rotated = x_rot * cos + rotate_half(x_rot, interleaved) * sin
    return torch.cat([rotated, x[..., rotary_dim:]], dim=-1)
33
+
34
+
35
class ApplyRotaryEmb(torch.autograd.Function):
    """Autograd wrapper around the fused ``apply_rotary`` kernel.

    The backward pass runs the same kernel with ``conjugate=True``, i.e. it
    rotates the incoming gradient by the opposite angle.
    """

    @staticmethod
    def forward(
        ctx,
        x,
        cos,
        sin,
        interleaved=False,
        inplace=False,
        seqlen_offsets: Union[int, torch.Tensor] = 0,
        cu_seqlens: Optional[torch.Tensor] = None,
        max_seqlen: Optional[int] = None,
    ):
        out = apply_rotary(
            x,
            cos,
            sin,
            seqlen_offsets=seqlen_offsets,
            cu_seqlens=cu_seqlens,
            max_seqlen=max_seqlen,
            interleaved=interleaved,
            inplace=inplace,
        )
        if isinstance(seqlen_offsets, int):
            ctx.save_for_backward(cos, sin, cu_seqlens)  # Can't save int with save_for_backward
            ctx.seqlen_offsets = seqlen_offsets
        else:
            # Tensor-valued offsets must go through save_for_backward.
            ctx.save_for_backward(cos, sin, cu_seqlens, seqlen_offsets)
            ctx.seqlen_offsets = None
        ctx.interleaved = interleaved
        ctx.inplace = inplace
        ctx.max_seqlen = max_seqlen
        return out if not inplace else x

    @staticmethod
    def backward(ctx, do):
        seqlen_offsets = ctx.seqlen_offsets
        if seqlen_offsets is None:
            # Offsets were a tensor, so they were stashed via save_for_backward.
            cos, sin, cu_seqlens, seqlen_offsets = ctx.saved_tensors
        else:
            cos, sin, cu_seqlens = ctx.saved_tensors
        # TD [2023-09-02]: For some reason Triton (2.0.0.post1) errors with
        # "[CUDA]: invalid device context", and cloning makes it work. Idk why. Triton 2.1.0 works.
        if not ctx.interleaved and not ctx.inplace:
            do = do.clone()
        dx = apply_rotary(
            do,
            cos,
            sin,
            seqlen_offsets=seqlen_offsets,
            cu_seqlens=cu_seqlens,
            max_seqlen=ctx.max_seqlen,
            interleaved=ctx.interleaved,
            inplace=ctx.inplace,
            conjugate=True,  # inverse rotation for the gradient
        )
        # One gradient slot per forward argument (x + 7 non-differentiable args).
        return dx, None, None, None, None, None, None, None
92
+
93
+
94
def apply_rotary_emb(
    x,
    cos,
    sin,
    interleaved=False,
    inplace=False,
    seqlen_offsets: Union[int, torch.Tensor] = 0,
    cu_seqlens: Optional[torch.Tensor] = None,
    max_seqlen: Optional[int] = None,
):
    """Apply rotary embedding to the first rotary_dim channels of x.

    Arguments:
        x: (batch_size, seqlen, nheads, headdim) when cu_seqlens is None,
            otherwise (total_seqlen, nheads, headdim) for the varlen layout.
        cos, sin: (seqlen_rotary, rotary_dim / 2) precomputed tables.
        interleaved: if True, rotate pairs of even and odd dimensions
            (GPT-J style) instead of 1st half and 2nd half (GPT-NeoX style).
        inplace: if True, apply rotary embedding in-place.
        seqlen_offsets: (batch_size,) or int. Each sequence in x is shifted by
            this amount; most commonly used in inference with a KV cache.
        cu_seqlens: (batch + 1,) cumulative sequence lengths, or None.
        max_seqlen: int, required together with cu_seqlens.
    Return:
        Tensor of the same shape as x. rotary_dim must be <= headdim.
    """
    fn_args = (x, cos, sin, interleaved, inplace, seqlen_offsets, cu_seqlens, max_seqlen)
    return ApplyRotaryEmb.apply(*fn_args)
125
+
126
+
127
+ # For backward compatibility
128
+ apply_rotary_emb_func = apply_rotary_emb
129
+
130
+
131
class ApplyRotaryEmbQKV_(torch.autograd.Function):
    """In-place rotary embedding on a packed QKV tensor, with autograd.

    Q and K are rotated in place; V is untouched. When Q and K use the same
    cos/sin tables and qkv is contiguous, Q and K are reshaped together so a
    single kernel launch handles both.
    """

    @staticmethod
    def forward(
        ctx,
        qkv,
        cos,
        sin,
        cos_k=None,
        sin_k=None,
        interleaved=False,
        seqlen_offsets: Union[int, torch.Tensor] = 0,
    ):
        batch, seqlen, three, nheads, headdim = qkv.shape
        assert three == 3
        if cos_k is None and sin_k is None and qkv.is_contiguous():
            # Call 1 kernel instead of 2 kernels
            # We need qkv to be contiguous so that when we reshape to combine (3, nheads)
            # dimensions, we get the same tensor
            # qk = rearrange(qkv[:, :, :2], "b s t h d -> b s (t h) d")
            qk = qkv[:, :, :2].reshape(batch, seqlen, -1, headdim)
            apply_rotary(
                qk, cos, sin, seqlen_offsets=seqlen_offsets, interleaved=interleaved, inplace=True
            )
        else:
            cos_k = cos if cos_k is None else cos_k
            sin_k = sin if sin_k is None else sin_k
            q, k = qkv[:, :, 0], qkv[:, :, 1]
            apply_rotary(q, cos, sin, seqlen_offsets, interleaved=interleaved, inplace=True)
            apply_rotary(k, cos_k, sin_k, seqlen_offsets, interleaved=interleaved, inplace=True)
        # NOTE: a redundant ctx.save_for_backward(cos, sin, cos_k, sin_k) that
        # used to sit at the end of the else-branch was removed: the branch
        # below always calls save_for_backward, and a later call replaces the
        # tensors stored by an earlier one, so the first call was dead code.
        if isinstance(seqlen_offsets, int):
            ctx.save_for_backward(cos, sin, cos_k, sin_k)  # Can't save int with save_for_backward
            ctx.seqlen_offsets = seqlen_offsets
        else:
            ctx.save_for_backward(cos, sin, cos_k, sin_k, seqlen_offsets)
            ctx.seqlen_offsets = None
        ctx.interleaved = interleaved
        return qkv

    @staticmethod
    def backward(ctx, dqkv):
        seqlen_offsets = ctx.seqlen_offsets
        if seqlen_offsets is None:
            # Offsets were a tensor, stashed via save_for_backward.
            cos, sin, cos_k, sin_k, seqlen_offsets = ctx.saved_tensors
        else:
            cos, sin, cos_k, sin_k = ctx.saved_tensors
        if cos_k is None and sin_k is None and dqkv.is_contiguous():
            # Call 1 kernel instead of 2 kernels
            # We need dqkv to be contiguous so that when we reshape to combine (3, nheads)
            # dimensions, we get the same tensor
            dqk = rearrange(dqkv[:, :, :2], "b s t h d -> b s (t h) d")
            apply_rotary(
                dqk,
                cos,
                sin,
                seqlen_offsets=seqlen_offsets,
                interleaved=ctx.interleaved,
                inplace=True,
                conjugate=True,  # inverse rotation for the gradient
            )
        else:
            cos_k = cos if cos_k is None else cos_k
            sin_k = sin if sin_k is None else sin_k
            dq, dk = dqkv[:, :, 0], dqkv[:, :, 1]
            apply_rotary(
                dq, cos, sin, seqlen_offsets, interleaved=ctx.interleaved, inplace=True, conjugate=True
            )
            apply_rotary(
                dk,
                cos_k,
                sin_k,
                seqlen_offsets,
                interleaved=ctx.interleaved,
                inplace=True,
                conjugate=True,
            )
        # One gradient slot per forward argument (qkv + 6 non-differentiable args).
        return dqkv, None, None, None, None, None, None
208
+
209
+
210
def apply_rotary_emb_qkv_(
    qkv,
    cos,
    sin,
    cos_k=None,
    sin_k=None,
    interleaved=False,
    seqlen_offsets: Union[int, torch.Tensor] = 0,
):
    """Apply rotary embedding *inplace* to the first rotary_dim of Q and K.

    Arguments:
        qkv: (batch_size, seqlen, 3, nheads, headdim)
        cos, sin: (seqlen, rotary_dim / 2) tables for Q (and K when
            cos_k / sin_k are not given).
        cos_k, sin_k: (seqlen, rotary_dim / 2), optional separate tables for K.
        interleaved: if True, rotate pairs of even and odd dimensions
            (GPT-J style) instead of 1st half and 2nd half (GPT-NeoX style).
        seqlen_offsets: (batch_size,) or int. Each sequence in Q and K is
            shifted by this amount; most commonly used in inference with a
            KV cache.
    Return:
        qkv: (batch_size, seqlen, 3, nheads, headdim); rotary_dim must be
        <= headdim.
    """
    fn_args = (qkv, cos, sin, cos_k, sin_k, interleaved, seqlen_offsets)
    return ApplyRotaryEmbQKV_.apply(*fn_args)
234
+
235
+
236
class ApplyRotaryEmbKV_(torch.autograd.Function):
    """In-place rotary embedding on a packed KV tensor, with autograd.

    Only K (kv[:, :, 0]) is rotated; V passes through unchanged, so the
    backward pass only needs to counter-rotate dK.
    """

    @staticmethod
    def forward(ctx, kv, cos, sin, interleaved=False, seqlen_offsets: Union[int, torch.Tensor] = 0):
        batch, seqlen, two, nheads, headdim = kv.shape
        assert two == 2
        k = kv[:, :, 0]
        apply_rotary(
            k, cos, sin, seqlen_offsets=seqlen_offsets, interleaved=interleaved, inplace=True
        )
        if isinstance(seqlen_offsets, int):
            ctx.save_for_backward(cos, sin)  # Can't save int with save_for_backward
            ctx.seqlen_offsets = seqlen_offsets
        else:
            # Tensor-valued offsets must go through save_for_backward.
            ctx.save_for_backward(cos, sin, seqlen_offsets)
            ctx.seqlen_offsets = None
        ctx.interleaved = interleaved
        return kv

    @staticmethod
    def backward(ctx, dkv):
        seqlen_offsets = ctx.seqlen_offsets
        if seqlen_offsets is None:
            # Offsets were a tensor, stashed via save_for_backward.
            cos, sin, seqlen_offsets = ctx.saved_tensors
        else:
            cos, sin = ctx.saved_tensors
        apply_rotary(
            dkv[:, :, 0],
            cos,
            sin,
            seqlen_offsets=seqlen_offsets,
            interleaved=ctx.interleaved,
            inplace=True,
            conjugate=True,  # inverse rotation for the gradient
        )
        # One gradient slot per forward argument (kv + 4 non-differentiable args).
        return dkv, None, None, None, None
271
+
272
+
273
+ apply_rotary_emb_kv_ = ApplyRotaryEmbKV_.apply
274
+
275
+
276
def apply_rotary_emb_kv_(
    kv,
    cos,
    sin,
    interleaved=False,
    seqlen_offsets: Union[int, torch.Tensor] = 0,
):
    """Apply rotary embedding *inplace* to the first rotary_dim of K.

    Arguments:
        kv: (batch_size, seqlen, 2, nheads, headdim)
        cos, sin: (seqlen, rotary_dim / 2)
        interleaved: if True, rotate pairs of even and odd dimensions
            (GPT-J style) instead of 1st half and 2nd half (GPT-NeoX style).
        seqlen_offsets: (batch_size,) or int. Each sequence in K is shifted by
            this amount; most commonly used in inference with a KV cache.
    Return:
        kv: (batch_size, seqlen, 2, nheads, headdim); rotary_dim must be
        <= headdim.
    """
    fn_args = (kv, cos, sin, interleaved, seqlen_offsets)
    return ApplyRotaryEmbKV_.apply(*fn_args)
297
+
298
+
299
class RotaryEmbedding(torch.nn.Module):
    """
    The rotary position embeddings from RoFormer_ (Su et. al).
    A crucial insight from the method is that the query and keys are
    transformed by rotation matrices which depend on the relative positions.

    Other implementations are available in the Rotary Transformer repo_ and in
    GPT-NeoX_, GPT-NeoX was an inspiration

    .. _RoFormer: https://arxiv.org/abs/2104.09864
    .. _repo: https://github.com/ZhuiyiTechnology/roformer
    .. _GPT-NeoX: https://github.com/EleutherAI/gpt-neox

    If scale_base is not None, this implements XPos (Sun et al., https://arxiv.org/abs/2212.10554).
    A recommended value for scale_base is 512: https://github.com/HazyResearch/flash-attention/issues/96
    Reference: https://github.com/sunyt32/torchscale/blob/main/torchscale/component/xpos_relative_position.py
    """

    def __init__(
        self,
        dim: int,
        base=10000.0,
        interleaved=False,
        scale_base=None,
        pos_idx_in_fp32=True,
        device=None,
    ):
        """
        interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead
            of 1st half and 2nd half (GPT-NeoX style).
        pos_idx_in_fp32: if True, the position indices [0.0, ..., seqlen - 1] are in fp32,
            otherwise they might be in lower precision.
            This option was added because previously (before 2023-07-02), when we construct
            the position indices, we use the dtype of self.inv_freq. In most cases this would
            be fp32, but if the model is trained in pure bf16 (not mixed precision), then
            self.inv_freq would be bf16, and the position indices are also in bf16.
            Because of the limited precision of bf16 (e.g. 1995.0 is rounded to 2000.0), the
            embeddings for some positions will coincide.
            To maintain compatibility with models previously trained in pure bf16,
            we add this option.
        """
        super().__init__()
        self.dim = dim
        self.base = float(base)
        self.pos_idx_in_fp32 = pos_idx_in_fp32
        # Generate and save the inverse frequency buffer (non trainable)
        inv_freq = self._compute_inv_freq(device)
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.interleaved = interleaved
        self.scale_base = scale_base
        # XPos per-channel scale; only built when scale_base is given.
        scale = (
            (torch.arange(0, dim, 2, device=device, dtype=torch.float32) + 0.4 * dim) / (1.4 * dim)
            if scale_base is not None
            else None
        )
        self.register_buffer("scale", scale, persistent=False)

        # Lazily-built cos/sin tables; (re)computed on demand by
        # _update_cos_sin_cache when seqlen / device / dtype change.
        self._seq_len_cached = 0
        self._cos_cached = None
        self._sin_cached = None
        self._cos_k_cached = None
        self._sin_k_cached = None

    def _compute_inv_freq(self, device=None):
        """Return the fp32 inverse-frequency vector 1 / base^(2i / dim), shape (dim / 2,)."""
        return 1.0 / (
            self.base
            ** (torch.arange(0, self.dim, 2, device=device, dtype=torch.float32) / self.dim)
        )

    def _update_cos_sin_cache(self, seqlen, device=None, dtype=None):
        # Reset the tables if the sequence length has changed,
        # if we're on a new device (possibly due to tracing for instance),
        # or if we're switching from inference mode to training
        if (
            seqlen > self._seq_len_cached
            or self._cos_cached is None
            or self._cos_cached.device != device
            or self._cos_cached.dtype != dtype
            or (self.training and self._cos_cached.is_inference())
        ):
            self._seq_len_cached = seqlen
            # We want fp32 here, not self.inv_freq.dtype, since the model could be loaded in bf16
            # And the output of arange can be quite large, so bf16 would lose a lot of precision.
            # However, for compatibility reason, we add an option to use the dtype of self.inv_freq.
            if self.pos_idx_in_fp32:
                t = torch.arange(seqlen, device=device, dtype=torch.float32)
                # We want fp32 here as well since inv_freq will be multiplied with t, and the output
                # will be large. Having it in bf16 will lose a lot of precision and cause the
                # cos & sin output to change significantly.
                # We want to recompute self.inv_freq if it was not loaded in fp32
                if self.inv_freq.dtype != torch.float32:
                    inv_freq = self._compute_inv_freq(device=device)
                else:
                    inv_freq = self.inv_freq
            else:
                t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
                inv_freq = self.inv_freq
            # Don't do einsum, it converts fp32 to fp16 under AMP
            # freqs = torch.einsum("i,j->ij", t, self.inv_freq)
            freqs = torch.outer(t, inv_freq)
            if self.scale is None:
                self._cos_cached = torch.cos(freqs).to(dtype)
                self._sin_cached = torch.sin(freqs).to(dtype)
            else:
                # XPos: positions are centered around seqlen // 2 so that the
                # scale is balanced between queries and keys.
                power = (
                    torch.arange(seqlen, dtype=self.scale.dtype, device=self.scale.device)
                    - seqlen // 2
                ) / self.scale_base
                scale = self.scale.to(device=power.device) ** rearrange(power, "s -> s 1")
                # We want the multiplication by scale to happen in fp32
                self._cos_cached = (torch.cos(freqs) * scale).to(dtype)
                self._sin_cached = (torch.sin(freqs) * scale).to(dtype)
                # Keys use the reciprocal scale.
                self._cos_k_cached = (torch.cos(freqs) / scale).to(dtype)
                self._sin_k_cached = (torch.sin(freqs) / scale).to(dtype)

    def forward(
        self,
        qkv: torch.Tensor,
        kv: Optional[torch.Tensor] = None,
        seqlen_offset: Union[int, torch.Tensor] = 0,
        max_seqlen: Optional[int] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """
        qkv: (batch, seqlen, 3, nheads, headdim) if kv is none,
            else it's just q of shape (batch, seqlen, nheads, headdim)
        kv: (batch, seqlen, 2, nheads, headdim)
        seqlen_offset: (batch_size,) or int. Each sequence in x is shifted by this amount.
            Most commonly used in inference when we have KV cache.
            If it's a tensor of shape (batch_size,), then to update the cos / sin cache, one
            should pass in max_seqlen, which will update the cos / sin cache up to that length.
        Apply rotary embedding *inplace* to qkv and / or kv.
        """
        seqlen = qkv.shape[1]
        # With tensor offsets and no max_seqlen, the cache is assumed to be
        # long enough already (no update is performed here).
        if max_seqlen is not None:
            self._update_cos_sin_cache(max_seqlen, device=qkv.device, dtype=qkv.dtype)
        elif isinstance(seqlen_offset, int):
            self._update_cos_sin_cache(seqlen + seqlen_offset, device=qkv.device, dtype=qkv.dtype)
        if kv is None:
            # Packed QKV path.
            if self.scale is None:
                return apply_rotary_emb_qkv_(
                    qkv,
                    self._cos_cached,
                    self._sin_cached,
                    interleaved=self.interleaved,
                    seqlen_offsets=seqlen_offset,
                )
            else:
                # XPos: K uses the reciprocal-scaled tables.
                return apply_rotary_emb_qkv_(
                    qkv,
                    self._cos_cached,
                    self._sin_cached,
                    self._cos_k_cached,
                    self._sin_k_cached,
                    interleaved=self.interleaved,
                    seqlen_offsets=seqlen_offset,
                )
        else:
            # Separate Q and KV path.
            q = qkv
            q = apply_rotary_emb_func(
                q,
                self._cos_cached,
                self._sin_cached,
                interleaved=self.interleaved,
                inplace=True,
                seqlen_offsets=seqlen_offset,
            )
            if self.scale is None:
                kv = apply_rotary_emb_kv_(
                    kv,
                    self._cos_cached,
                    self._sin_cached,
                    interleaved=self.interleaved,
                    seqlen_offsets=seqlen_offset,
                )
            else:
                kv = apply_rotary_emb_kv_(
                    kv,
                    self._cos_k_cached,
                    self._sin_k_cached,
                    interleaved=self.interleaved,
                    seqlen_offsets=seqlen_offset,
                )
            return q, kv
.venv/lib/python3.11/site-packages/xformers/_flash_attn/losses/__init__.py ADDED
File without changes
.venv/lib/python3.11/site-packages/xformers/_flash_attn/losses/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (200 Bytes). View file
 
.venv/lib/python3.11/site-packages/xformers/_flash_attn/losses/__pycache__/cross_entropy.cpython-311.pyc ADDED
Binary file (3.89 kB). View file
 
.venv/lib/python3.11/site-packages/xformers/_flash_attn/losses/cross_entropy.py ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2024, Tri Dao.
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+
6
+ from flash_attn.ops.triton.cross_entropy import cross_entropy_loss
7
+
8
+
9
class CrossEntropyLoss(nn.Module):
    """Cross-entropy loss backed by the fused Triton ``cross_entropy_loss`` kernel.

    Supports label smoothing, logit scaling, a z-loss term, in-place backward,
    and tensor-parallel vocab sharding via ``process_group``.
    """

    def __init__(
        self,
        ignore_index=-100,
        reduction="mean",
        label_smoothing=0.0,
        logit_scale=1.0,
        lse_square_scale=0.0,
        inplace_backward=False,
        process_group=None,
        return_z_loss=False,
    ):
        """
        Arguments:
            ignore_index: int. If labels == ignore_index, the loss is set to 0.0.
            reduction: "mean", "sum", or "none".
            label_smoothing: float
            logit_scale: float. Logits are multiplied by this before the loss.
            lse_square_scale: float. If > 0, we add lse_square_scale * lse(logits) ^ 2 to the loss.
                This is also referred to as "z-loss".
            inplace_backward: bool. If True, we do the backward pass in-place by modifying the logits.
                This saves memory.
            process_group: if not None, we're doing Tensor Parallel: each process is responsible for
                one part of the vocab. The loss will be aggregated across processes.
            return_z_loss: bool. If True, we return the component of the loss contributed by
                the lse_square_scale value. This value is only for logging and does not support
                backprop.
        Raises:
            NotImplementedError: if reduction is not one of "mean", "none", "sum".
        """
        super().__init__()
        if reduction not in ["mean", "none", "sum"]:
            raise NotImplementedError("Only support reduction = 'mean' or 'none' or 'sum'")
        self.ignore_index = ignore_index
        self.reduction = reduction
        self.label_smoothing = label_smoothing
        self.logit_scale = logit_scale
        self.lse_square_scale = lse_square_scale
        self.inplace_backward = inplace_backward
        self.process_group = process_group
        self.return_z_loss = return_z_loss

    def _reduce(self, values, target):
        # Apply self.reduction to a per-sample loss vector; the "mean" divisor
        # counts only the non-ignored targets.
        if self.reduction == "mean":
            return values.sum() / (target != self.ignore_index).sum()
        if self.reduction == "sum":
            return values.sum()
        return values  # reduction == "none"

    def forward(self, input, target, precomputed_lse=None):
        """
        Arguments:
            input: (batch, vocab_size)
            target: (batch,)
            precomputed_lse: optional precomputed logsumexp of the logits,
                forwarded to the kernel.
        Returns:
            losses: (batch,) if reduction is 'none', else (1,), dtype float
            z_loss: (batch,) if reduction is 'none', else (1,), dtype float (if self.return_z_loss)
        """
        assert input.is_cuda and target.is_cuda, "Only support CUDA tensors"
        loss, z_loss = cross_entropy_loss(
            input,
            target,
            precomputed_lse=precomputed_lse,
            label_smoothing=self.label_smoothing,
            logit_scale=self.logit_scale,
            lse_square_scale=self.lse_square_scale,
            ignore_index=self.ignore_index,
            inplace_backward=self.inplace_backward,
            process_group=self.process_group,
        )
        loss = self._reduce(loss, target)
        if not self.return_z_loss:
            return loss
        return loss, self._reduce(z_loss, target)
.venv/lib/python3.11/site-packages/xformers/_flash_attn/models/__init__.py ADDED
File without changes
.venv/lib/python3.11/site-packages/xformers/_flash_attn/models/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (200 Bytes). View file
 
.venv/lib/python3.11/site-packages/xformers/_flash_attn/models/__pycache__/baichuan.cpython-311.pyc ADDED
Binary file (7.82 kB). View file
 
.venv/lib/python3.11/site-packages/xformers/_flash_attn/models/__pycache__/bert.cpython-311.pyc ADDED
Binary file (41.9 kB). View file
 
.venv/lib/python3.11/site-packages/xformers/_flash_attn/models/__pycache__/bigcode.cpython-311.pyc ADDED
Binary file (13.1 kB). View file
 
.venv/lib/python3.11/site-packages/xformers/_flash_attn/models/__pycache__/btlm.cpython-311.pyc ADDED
Binary file (7.7 kB). View file
 
.venv/lib/python3.11/site-packages/xformers/_flash_attn/models/__pycache__/falcon.cpython-311.pyc ADDED
Binary file (8.53 kB). View file
 
.venv/lib/python3.11/site-packages/xformers/_flash_attn/models/__pycache__/gpt.cpython-311.pyc ADDED
Binary file (54.6 kB). View file
 
.venv/lib/python3.11/site-packages/xformers/_flash_attn/models/__pycache__/gpt_neox.cpython-311.pyc ADDED
Binary file (7.87 kB). View file
 
.venv/lib/python3.11/site-packages/xformers/_flash_attn/models/__pycache__/gptj.cpython-311.pyc ADDED
Binary file (7.35 kB). View file
 
.venv/lib/python3.11/site-packages/xformers/_flash_attn/models/__pycache__/llama.cpython-311.pyc ADDED
Binary file (23.1 kB). View file
 
.venv/lib/python3.11/site-packages/xformers/_flash_attn/models/__pycache__/opt.cpython-311.pyc ADDED
Binary file (7.75 kB). View file
 
.venv/lib/python3.11/site-packages/xformers/_flash_attn/models/__pycache__/vit.cpython-311.pyc ADDED
Binary file (16.7 kB). View file
 
.venv/lib/python3.11/site-packages/xformers/_flash_attn/models/baichuan.py ADDED
@@ -0,0 +1,151 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2023, GGGGGGXY, Tri Dao.
2
+
3
+ import math
4
+ import json
5
+ import re
6
+ from pathlib import Path
7
+
8
+ from collections import OrderedDict
9
+
10
+ import torch
11
+ import torch.nn.functional as F
12
+
13
+ from einops import rearrange
14
+ from transformers import GPT2Config, AutoConfig, PretrainedConfig
15
+
16
+
17
+ def remap_state_dict_hf_baichuan(state_dict, config):
18
+ def key_mapping_layers(key):
19
+ return re.sub(r"^model.", "transformer.", key)
20
+
21
+ state_dict = OrderedDict((key_mapping_layers(k), v) for k, v in state_dict.items())
22
+
23
+ # Word embedding
24
+ def key_mapping_emb(key):
25
+ return re.sub(
26
+ r"^transformer.embed_tokens.",
27
+ "transformer.embeddings.word_embeddings.",
28
+ key,
29
+ )
30
+
31
+ state_dict = OrderedDict((key_mapping_emb(k), v) for k, v in state_dict.items())
32
+ word_embeddings = state_dict.pop("transformer.embeddings.word_embeddings.weight")
33
+ # It's possible that vocab_size is padded to be a multiple of 8, for example.
34
+ pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
35
+ vocab_size = (
36
+ math.ceil(word_embeddings.shape[0] / pad_vocab_size_multiple)
37
+ * pad_vocab_size_multiple
38
+ )
39
+ state_dict["transformer.embeddings.word_embeddings.weight"] = F.pad(
40
+ word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0])
41
+ )
42
+ if getattr(config, "tie_word_embeddings"):
43
+ state_dict["lm_head.weight"] = state_dict[
44
+ "transformer.embeddings.word_embeddings.weight"
45
+ ]
46
+ else:
47
+ output_embeddings = state_dict.pop("lm_head.weight")
48
+ # Need to recompute vocab_size since Baichuan shards the word embeddings and output embeddings
49
+ # differently.
50
+ vocab_size = (
51
+ math.ceil(output_embeddings.shape[0] / pad_vocab_size_multiple)
52
+ * pad_vocab_size_multiple
53
+ )
54
+ # It's possible that vocab_size is padded to be a multiple of 8, for example.
55
+ state_dict["lm_head.weight"] = F.pad(
56
+ output_embeddings, (0, 0, 0, vocab_size - output_embeddings.shape[0])
57
+ )
58
+
59
+ # LayerNorm
60
+ def key_mapping_ln(key):
61
+ key = re.sub(r"^transformer.norm.", r"transformer.ln_f.", key)
62
+ key = re.sub(
63
+ r"^transformer.layers.(\d+).input_layernorm.",
64
+ r"transformer.layers.\1.norm1.",
65
+ key,
66
+ )
67
+ key = re.sub(
68
+ r"^transformer.layers.(\d+).post_attention_layernorm.",
69
+ r"transformer.layers.\1.norm2.",
70
+ key,
71
+ )
72
+ return key
73
+
74
+ state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items())
75
+
76
+ # MLP
77
+ for l in range(config.n_layer):
78
+ w1 = state_dict.pop(f"transformer.layers.{l}.mlp.gate_proj.weight")
79
+ w3 = state_dict.pop(f"transformer.layers.{l}.mlp.up_proj.weight")
80
+ # Our ordering is different
81
+ state_dict[f"transformer.layers.{l}.mlp.fc1.weight"] = torch.cat(
82
+ [w3, w1], dim=0
83
+ )
84
+
85
+ def key_mapping_mlp(key):
86
+ return re.sub(
87
+ r"^transformer.layers.(\d+).mlp.down_proj.",
88
+ r"transformer.layers.\1.mlp.fc2.",
89
+ key,
90
+ )
91
+
92
+ state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items())
93
+
94
+ # Attention
95
+ def key_mapping_attn(key):
96
+ key = re.sub(
97
+ r"^transformer.layers.(\d+).self_attn.W_pack.",
98
+ r"transformer.layers.\1.mixer.Wqkv.",
99
+ key,
100
+ )
101
+ key = re.sub(
102
+ r"^transformer.layers.(\d+).self_attn.o_proj.",
103
+ r"transformer.layers.\1.mixer.out_proj.",
104
+ key,
105
+ )
106
+ return key
107
+
108
+ state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())
109
+ for l in range(config.n_layer):
110
+ # pop rotary_emb.inv_freq from state dict
111
+ state_dict.pop(f"transformer.layers.{l}.self_attn.rotary_emb.inv_freq", None)
112
+ return state_dict
113
+
114
+
115
+ def baichuan_config_to_gpt2_config(baichuan_config: PretrainedConfig) -> GPT2Config:
116
+ # HACK: the config doesn't have say whether it's rotary or alibi.
117
+ # So we have to infer from the hidden size (7B -> rotary, 13B -> alibi).
118
+ # HACK: the config doesn't have say whether it uses norm head.
119
+ # So we have to infer from the vocab size
120
+ # (v1, vocab size 64k, no norm head; v2, vocab size 128k, norm head).
121
+ use_rotary = baichuan_config.hidden_size < 5000
122
+ return GPT2Config(
123
+ vocab_size=baichuan_config.vocab_size,
124
+ n_positions=0, # No absolute position embedding
125
+ n_embd=baichuan_config.hidden_size,
126
+ n_layer=baichuan_config.num_hidden_layers,
127
+ n_head=baichuan_config.num_attention_heads,
128
+ n_inner=baichuan_config.intermediate_size,
129
+ activation_function="swiglu", # Hardcode since HF calls it 'silu'
130
+ # baichuan doesn't have dropout, idk if it's because they only release the inference code
131
+ resid_pdrop=0.0,
132
+ embd_pdrop=0.0,
133
+ attn_pdrop=0.0,
134
+ layer_norm_epsilon=baichuan_config.rms_norm_eps,
135
+ initializer_range=baichuan_config.initializer_range,
136
+ bos_token_id=baichuan_config.bos_token_id,
137
+ eos_token_id=baichuan_config.eos_token_id,
138
+ # These are new arguments not in the original GPT2Config
139
+ pad_token_id=baichuan_config.pad_token_id, # Idk if this does anything
140
+ rms_norm=True,
141
+ rotary_emb_fraction=1.0 if use_rotary else 0.0,
142
+ rotary_emb_interleaved=False,
143
+ use_alibi=not use_rotary,
144
+ use_flash_attn=not use_rotary, # Alibi code path requires flash_attn
145
+ tie_word_embeddings=False,
146
+ norm_head=baichuan_config.vocab_size > 70000,
147
+ qkv_proj_bias=False,
148
+ out_proj_bias=False,
149
+ mlp_fc1_bias=False,
150
+ mlp_fc2_bias=False,
151
+ )
.venv/lib/python3.11/site-packages/xformers/_flash_attn/models/bert.py ADDED
@@ -0,0 +1,764 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2022, Tri Dao.
2
+ # This BERT implementation is based on our MLPerf 2.0 and MLPerf 2.1 BERT implementation.
3
+ # https://github.com/mlcommons/training_results_v2.0/blob/main/HazyResearch/benchmarks/bert/implementations/pytorch/modeling.py
4
+ # https://github.com/mlcommons/training_results_v2.1/blob/main/Azure-HazyResearch/benchmarks/bert/implementations/ND96amsr_A100_v4/modeling.py
5
+
6
+ # Inspired by https://github.com/huggingface/transformers/blob/main/src/transformers/models/bert/modeling_bert.py
7
+
8
+ import logging
9
+ import re
10
+ from collections import OrderedDict
11
+ from collections.abc import Sequence
12
+ from functools import partial
13
+ from typing import Any, Mapping
14
+
15
+ import torch
16
+ import torch.nn as nn
17
+ import torch.nn.functional as F
18
+ from einops import rearrange
19
+ from transformers import BertConfig, PretrainedConfig
20
+ from transformers.models.bert.modeling_bert import (
21
+ BaseModelOutputWithPoolingAndCrossAttentions,
22
+ BertForPreTrainingOutput,
23
+ )
24
+
25
+ from flash_attn.bert_padding import (
26
+ index_first_axis,
27
+ index_first_axis_residual,
28
+ pad_input,
29
+ unpad_input,
30
+ )
31
+ from flash_attn.modules.block import Block
32
+ from flash_attn.modules.embedding import BertEmbeddings
33
+ from flash_attn.modules.mha import MHA
34
+ from flash_attn.modules.mlp import FusedMLP, Mlp
35
+ from flash_attn.utils.pretrained import state_dict_from_pretrained
36
+
37
+ try:
38
+ from flash_attn.ops.fused_dense import FusedDense
39
+ except ImportError:
40
+ FusedDense = None
41
+
42
+ try:
43
+ from flash_attn.ops.triton.layer_norm import layer_norm_fn
44
+ except ImportError:
45
+ layer_norm_fn = None
46
+
47
+
48
+ try:
49
+ from flash_attn.losses.cross_entropy import CrossEntropyLoss
50
+ except ImportError:
51
+ CrossEntropyLoss = None
52
+
53
+
54
+ logger = logging.getLogger(__name__)
55
+
56
+
57
+ def create_mixer_cls(config, cross_attn=False, return_residual=False):
58
+ use_flash_attn = getattr(config, "use_flash_attn", False)
59
+ fused_bias_fc = getattr(config, "fused_bias_fc", False)
60
+ rotary_kwargs = {}
61
+ if config.position_embedding_type == "rotary":
62
+ rotary_kwargs["rotary_emb_dim"] = getattr(config, "rotary_emb_dim", config.hidden_size)
63
+ rotary_kwargs["rotary_emb_base"] = getattr(config, "rotary_emb_base", 10000.0)
64
+ rotary_kwargs["rotary_emb_scale_base"] = getattr(config, "rotary_emb_scale_base", None)
65
+ rotary_kwargs["rotary_emb_interleaved"] = getattr(config, "rotary_emb_interleaved", False)
66
+ mixer_cls = partial(
67
+ MHA,
68
+ num_heads=config.num_attention_heads,
69
+ cross_attn=cross_attn,
70
+ dropout=config.attention_probs_dropout_prob,
71
+ causal=False,
72
+ fused_bias_fc=fused_bias_fc,
73
+ use_flash_attn=use_flash_attn,
74
+ return_residual=return_residual,
75
+ **rotary_kwargs,
76
+ )
77
+ return mixer_cls
78
+
79
+
80
+ def create_mlp_cls(config, layer_idx=None, return_residual=False):
81
+ inner_dim = config.intermediate_size
82
+ fused_mlp = getattr(config, "fused_mlp", False)
83
+ if fused_mlp:
84
+ assert config.hidden_act in ["gelu_new", "gelu_fast", "gelu_pytorch_tanh"], (
85
+ "fused_mlp only " "supports approximate gelu"
86
+ )
87
+ if not fused_mlp:
88
+ approximate = (
89
+ "tanh"
90
+ if config.hidden_act in ["gelu_new", "gelu_fast", "gelu_pytorch_tanh"]
91
+ else "none"
92
+ )
93
+ mlp_cls = partial(
94
+ Mlp,
95
+ hidden_features=inner_dim,
96
+ activation=partial(F.gelu, approximate=approximate),
97
+ return_residual=return_residual,
98
+ )
99
+ else:
100
+ if FusedMLP is None:
101
+ raise ImportError("fused_dense is not installed")
102
+ mlp_checkpoint_lvl = getattr(config, "mlp_checkpoint_lvl", 0)
103
+ # mlp_checkpoint_lvl could be a list, which contains the checkpoint_lvl for each layer
104
+ if isinstance(mlp_checkpoint_lvl, Sequence):
105
+ assert layer_idx is not None
106
+ mlp_checkpoint_lvl = mlp_checkpoint_lvl[layer_idx]
107
+ mlp_cls = partial(
108
+ FusedMLP,
109
+ hidden_features=inner_dim,
110
+ checkpoint_lvl=mlp_checkpoint_lvl,
111
+ return_residual=return_residual,
112
+ )
113
+ return mlp_cls
114
+
115
+
116
+ def create_block(config, layer_idx=None):
117
+ last_layer_subset = getattr(config, "last_layer_subset", False)
118
+ cross_attn = last_layer_subset and layer_idx == config.num_hidden_layers - 1
119
+ # TD [2022-12-19]: For cross attention (last layer), we actually want to return the
120
+ # residual x_kv, not residual x. But it's annoying to change the API (and it only affects
121
+ # one layer) so we just choose not to return residual in this case.
122
+ return_residual = not cross_attn
123
+ mixer_cls = create_mixer_cls(config, cross_attn, return_residual=return_residual)
124
+ mlp_cls = create_mlp_cls(config, layer_idx, return_residual=return_residual)
125
+ norm_cls = partial(nn.LayerNorm, eps=config.layer_norm_eps)
126
+ block = Block(
127
+ config.hidden_size,
128
+ mixer_cls,
129
+ mlp_cls,
130
+ norm_cls=norm_cls,
131
+ prenorm=False,
132
+ resid_dropout1=config.hidden_dropout_prob,
133
+ resid_dropout2=config.hidden_dropout_prob,
134
+ fused_dropout_add_ln=getattr(config, "fused_dropout_add_ln", False),
135
+ return_residual=return_residual,
136
+ )
137
+ return block
138
+
139
+
140
+ # https://github.com/huggingface/transformers/blob/7032e0203262ebb2ebf55da8d2e01f873973e835/src/transformers/models/bert/modeling_bert.py#L748
141
+ def _init_weights(module, initializer_range=0.02):
142
+ if isinstance(module, nn.Linear):
143
+ nn.init.normal_(module.weight, std=initializer_range)
144
+ if module.bias is not None:
145
+ nn.init.zeros_(module.bias)
146
+ elif isinstance(module, nn.Embedding):
147
+ nn.init.normal_(module.weight, std=initializer_range)
148
+ if module.padding_idx is not None:
149
+ nn.init.zeros_(module.weight[module.padding_idx])
150
+
151
+
152
+ class BertEncoder(nn.Module):
153
+ def __init__(self, config: BertConfig):
154
+ super().__init__()
155
+ self.use_flash_attn = getattr(config, "use_flash_attn", False)
156
+ self.layers = nn.ModuleList(
157
+ [create_block(config, layer_idx=i) for i in range(config.num_hidden_layers)]
158
+ )
159
+
160
+ def forward(self, hidden_states, key_padding_mask=None, subset_mask=None):
161
+ """If subset_mask is not None, we only want output for the subset of the sequence.
162
+ This means that we only compute the last layer output for these tokens.
163
+ subset_mask: (batch, seqlen), dtype=torch.bool
164
+ """
165
+ if key_padding_mask is None or not self.use_flash_attn:
166
+ mixer_kwargs = (
167
+ {"key_padding_mask": key_padding_mask} if key_padding_mask is not None else None
168
+ )
169
+ for layer in self.layers:
170
+ hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)
171
+ if subset_mask is not None:
172
+ hidden_states = hidden_states[subset_mask]
173
+ else:
174
+ batch, seqlen = hidden_states.shape[:2]
175
+ hidden_states, indices, cu_seqlens, max_seqlen_in_batch = unpad_input(
176
+ hidden_states, key_padding_mask
177
+ )
178
+ mixer_kwargs = {"cu_seqlens": cu_seqlens, "max_seqlen": max_seqlen_in_batch}
179
+ if subset_mask is None:
180
+ for layer in self.layers:
181
+ hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)
182
+ hidden_states = pad_input(hidden_states, indices, batch, seqlen)
183
+ else:
184
+ for layer in self.layers[:-1]:
185
+ hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)
186
+ if key_padding_mask is not None:
187
+ subset_idx = torch.nonzero(
188
+ subset_mask[key_padding_mask], as_tuple=False
189
+ ).flatten()
190
+ subset_seqlens = (subset_mask & key_padding_mask).sum(dim=-1, dtype=torch.int32)
191
+ subset_cu_seqlens = F.pad(
192
+ torch.cumsum(subset_seqlens, dim=0, dtype=torch.torch.int32), (1, 0)
193
+ )
194
+ else:
195
+ subset_idx = torch.nonzero(subset_mask, as_tuple=False).flatten()
196
+ subset_seqlens = subset_mask.sum(dim=-1, dtype=torch.int32)
197
+ subset_cu_seqlens = F.pad(
198
+ torch.cumsum(subset_seqlens, dim=0, dtype=torch.torch.int32), (1, 0)
199
+ )
200
+ hidden_states_subset, hidden_states = index_first_axis_residual(
201
+ hidden_states, subset_idx
202
+ )
203
+ # It's ok to set max_seqlen_q to be much larger
204
+ mixer_kwargs = {
205
+ "x_kv": hidden_states,
206
+ "cu_seqlens": subset_cu_seqlens,
207
+ "max_seqlen": max_seqlen_in_batch,
208
+ "cu_seqlens_k": cu_seqlens,
209
+ "max_seqlen_k": max_seqlen_in_batch,
210
+ }
211
+ hidden_states = self.layers[-1](hidden_states_subset, mixer_kwargs=mixer_kwargs)
212
+ return hidden_states
213
+
214
+
215
+ class BertPooler(nn.Module):
216
+ def __init__(self, config):
217
+ super().__init__()
218
+ fused_bias_fc = getattr(config, "fused_bias_fc", False)
219
+ if fused_bias_fc and FusedDense is None:
220
+ raise ImportError("fused_dense is not installed")
221
+ linear_cls = nn.Linear if not fused_bias_fc else FusedDense
222
+ self.dense = linear_cls(config.hidden_size, config.hidden_size)
223
+ self.activation = nn.Tanh()
224
+
225
+ def forward(self, hidden_states, pool=True):
226
+ # We "pool" the model by simply taking the hidden state corresponding
227
+ # to the first token.
228
+ first_token_tensor = hidden_states[:, 0] if pool else hidden_states
229
+ pooled_output = self.dense(first_token_tensor)
230
+ pooled_output = self.activation(pooled_output)
231
+ return pooled_output
232
+
233
+
234
+ class BertPredictionHeadTransform(nn.Module):
235
+ def __init__(self, config):
236
+ super().__init__()
237
+ fused_bias_fc = getattr(config, "fused_bias_fc", False)
238
+ if fused_bias_fc and FusedDense is None:
239
+ raise ImportError("fused_dense is not installed")
240
+ self.fused_dropout_add_ln = getattr(config, "fused_dropout_add_ln", False)
241
+ if self.fused_dropout_add_ln and layer_norm_fn is None:
242
+ raise ImportError("Triton is not installed")
243
+ linear_cls = nn.Linear if not fused_bias_fc else FusedDense
244
+ self.dense = linear_cls(config.hidden_size, config.hidden_size)
245
+ approximate = (
246
+ "tanh"
247
+ if config.hidden_act in ["gelu_new", "gelu_fast", "gelu_pytorch_tanh"]
248
+ else "none"
249
+ )
250
+ self.transform_act_fn = nn.GELU(approximate=approximate)
251
+ self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
252
+
253
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
254
+ hidden_states = self.dense(hidden_states)
255
+ hidden_states = self.transform_act_fn(hidden_states)
256
+ if not self.fused_dropout_add_ln:
257
+ hidden_states = self.layer_norm(hidden_states)
258
+ else:
259
+ hidden_states = layer_norm_fn(
260
+ hidden_states, self.layer_norm.weight, self.layer_norm.bias, eps=self.layer_norm.eps
261
+ )
262
+ return hidden_states
263
+
264
+
265
+ class BertLMPredictionHead(nn.Module):
266
+ def __init__(self, config):
267
+ super().__init__()
268
+ fused_bias_fc = getattr(config, "fused_bias_fc", False)
269
+ if fused_bias_fc and FusedDense is None:
270
+ raise ImportError("fused_dense is not installed")
271
+ linear_cls = nn.Linear if not fused_bias_fc else FusedDense
272
+
273
+ self.transform = BertPredictionHeadTransform(config)
274
+
275
+ # The output weights are the same as the input embeddings, but there is
276
+ # an output-only bias for each token.
277
+ self.decoder = linear_cls(config.hidden_size, config.vocab_size, bias=True)
278
+
279
+ def forward(self, hidden_states):
280
+ hidden_states = self.transform(hidden_states)
281
+ hidden_states = self.decoder(hidden_states)
282
+ return hidden_states
283
+
284
+
285
+ class BertPreTrainingHeads(nn.Module):
286
+ def __init__(self, config):
287
+ super().__init__()
288
+ self.predictions = BertLMPredictionHead(config)
289
+ self.seq_relationship = nn.Linear(config.hidden_size, 2)
290
+
291
+ def forward(self, sequence_output, pooled_output):
292
+ prediction_scores = self.predictions(sequence_output)
293
+ seq_relationship_score = self.seq_relationship(pooled_output)
294
+ return prediction_scores, seq_relationship_score
295
+
296
+
297
+ class BertPreTrainedModel(nn.Module):
298
+ """An abstract class to handle weights initialization and
299
+ a simple interface for dowloading and loading pretrained models.
300
+ """
301
+
302
+ def __init__(self, config, *inputs, **kwargs):
303
+ super().__init__()
304
+ if not isinstance(config, BertConfig):
305
+ raise ValueError(
306
+ "Parameter config in `{}(config)` should be an instance of class `BertConfig`. "
307
+ "To create a model from a Google pretrained model use "
308
+ "`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format(
309
+ self.__class__.__name__, self.__class__.__name__
310
+ )
311
+ )
312
+ self.config = config
313
+
314
+ @classmethod
315
+ def from_pretrained(cls, model_name, config, *inputs, **kwargs):
316
+ """
317
+ Instantiate a BertPreTrainedModel from a pre-trained model file or a pytorch state dict.
318
+ Download and cache the pre-trained model file if needed.
319
+
320
+ Params:
321
+ pretrained_model_name_or_path: either:
322
+ - a path or url to a pretrained model archive containing:
323
+ . `bert_config.json` a configuration file for the model
324
+ . `pytorch_model.bin` a PyTorch dump of a BertForPretraining instance
325
+ - a path or url to a pretrained model archive containing:
326
+ . `bert_config.json` a configuration file for the model
327
+ . `model.chkpt` a TensorFlow checkpoint
328
+ *inputs, **kwargs: additional input for the specific Bert class
329
+ (ex: num_labels for BertForSequenceClassification)
330
+ """
331
+ # Instantiate model.
332
+ model = cls(config, *inputs, **kwargs)
333
+ load_return = model.load_state_dict(
334
+ remap_state_dict(state_dict_from_pretrained(model_name), config), strict=False
335
+ )
336
+ logger.info(load_return)
337
+ return model
338
+
339
+
340
+ class BertModel(BertPreTrainedModel):
341
+ def __init__(self, config: BertConfig, add_pooling_layer=True):
342
+ super().__init__(config)
343
+ self.pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
344
+ if config.vocab_size % self.pad_vocab_size_multiple != 0:
345
+ config.vocab_size += self.pad_vocab_size_multiple - (
346
+ config.vocab_size % self.pad_vocab_size_multiple
347
+ )
348
+ self.fused_dropout_add_ln = getattr(config, "fused_dropout_add_ln", False)
349
+ if self.fused_dropout_add_ln and layer_norm_fn is None:
350
+ raise ImportError("Triton is not installed")
351
+ assert config.hidden_act in ["gelu", "gelu_new", "gelu_fast", "gelu_pytorch_tanh"]
352
+
353
+ self.embeddings = BertEmbeddings(
354
+ config.hidden_size,
355
+ config.vocab_size,
356
+ config.max_position_embeddings,
357
+ config.type_vocab_size,
358
+ padding_idx=config.pad_token_id,
359
+ )
360
+ self.emb_drop = nn.Dropout(config.hidden_dropout_prob)
361
+ self.emb_ln = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
362
+ self.encoder = BertEncoder(config)
363
+ self.pooler = BertPooler(config) if add_pooling_layer else None
364
+
365
+ self.apply(partial(_init_weights, initializer_range=config.initializer_range))
366
+
367
+ def forward(
368
+ self,
369
+ input_ids,
370
+ position_ids=None,
371
+ token_type_ids=None,
372
+ attention_mask=None,
373
+ masked_tokens_mask=None,
374
+ ):
375
+ """If masked_tokens_mask is not None (i.e. last_layer_subset == True in BertForPreTraining),
376
+ we only want the output for the masked tokens. This means that we only compute the last
377
+ layer output for these tokens.
378
+ masked_tokens_mask: (batch, seqlen), dtype=torch.bool
379
+ """
380
+ hidden_states = self.embeddings(
381
+ input_ids, position_ids=position_ids, token_type_ids=token_type_ids
382
+ )
383
+ # TD [2022-12:18]: Don't need to force residual in fp32
384
+ # BERT puts embedding LayerNorm before embedding dropout.
385
+ if not self.fused_dropout_add_ln:
386
+ hidden_states = self.emb_ln(hidden_states)
387
+ else:
388
+ hidden_states = layer_norm_fn(
389
+ hidden_states, self.emb_ln.weight, self.emb_ln.bias, eps=self.emb_ln.eps
390
+ )
391
+ hidden_states = self.emb_drop(hidden_states)
392
+
393
+ if masked_tokens_mask is not None:
394
+ batch_size, seqlen = input_ids.shape[:2]
395
+ # We also need the first column for the CLS token
396
+ first_col_mask = torch.zeros(
397
+ batch_size, seqlen, dtype=torch.bool, device=input_ids.device
398
+ )
399
+ first_col_mask[:, 0] = True
400
+ subset_mask = masked_tokens_mask | first_col_mask
401
+ else:
402
+ subset_mask = None
403
+
404
+ sequence_output = self.encoder(
405
+ hidden_states, key_padding_mask=attention_mask, subset_mask=subset_mask
406
+ )
407
+
408
+ if masked_tokens_mask is None:
409
+ pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
410
+ else:
411
+ # TD [2022-03-01]: the indexing here is very tricky.
412
+ if attention_mask is not None:
413
+ subset_idx = subset_mask[attention_mask]
414
+ pool_input = sequence_output[first_col_mask[attention_mask][subset_idx]]
415
+ sequence_output = sequence_output[masked_tokens_mask[attention_mask][subset_idx]]
416
+ else:
417
+ pool_input = sequence_output[first_col_mask[subset_mask]]
418
+ sequence_output = sequence_output[masked_tokens_mask[subset_mask]]
419
+ pooled_output = self.pooler(pool_input, pool=False) if self.pooler is not None else None
420
+
421
+ return BaseModelOutputWithPoolingAndCrossAttentions(
422
+ last_hidden_state=sequence_output,
423
+ pooler_output=pooled_output,
424
+ )
425
+
426
+
427
+ class BertForPreTraining(BertPreTrainedModel):
428
+ def __init__(self, config: BertConfig):
429
+ super().__init__(config)
430
+ # If dense_seq_output, we only need to pass the hidden states for the masked out tokens
431
+ # (around 15%) to the classifier heads.
432
+ self.dense_seq_output = getattr(config, "dense_seq_output", False)
433
+ # If last_layer_subset, we only need the compute the last layer for a subset of tokens
434
+ # (e.g., the tokens we need to compute the masked LM loss and the next-sentence prediction).
435
+ self.last_layer_subset = getattr(config, "last_layer_subset", False)
436
+ if self.last_layer_subset:
437
+ assert self.dense_seq_output, "last_layer_subset requires dense_seq_output"
438
+ use_xentropy = getattr(config, "use_xentropy", False)
439
+ if use_xentropy and CrossEntropyLoss is None:
440
+ raise ImportError("xentropy_cuda is not installed")
441
+ loss_cls = (
442
+ nn.CrossEntropyLoss
443
+ if not use_xentropy
444
+ else partial(CrossEntropyLoss, inplace_backward=True)
445
+ )
446
+
447
+ self.bert = BertModel(config)
448
+ self.cls = BertPreTrainingHeads(config)
449
+ self.mlm_loss = loss_cls(ignore_index=0)
450
+ self.nsp_loss = loss_cls(ignore_index=-1)
451
+
452
+ # Initialize weights and apply final processing
453
+ self.apply(partial(_init_weights, initializer_range=config.initializer_range))
454
+ self.tie_weights()
455
+
456
+ def tie_weights(self):
457
+ self.cls.predictions.decoder.weight = self.bert.embeddings.word_embeddings.weight
458
+
459
+ def forward(
460
+ self,
461
+ input_ids,
462
+ position_ids=None,
463
+ token_type_ids=None,
464
+ attention_mask=None,
465
+ labels=None,
466
+ next_sentence_label=None,
467
+ ):
468
+ """
469
+ If labels are provided, they must be 0 for masked out tokens (as specified in the attention
470
+ mask).
471
+ Outputs:
472
+ if `labels` and `next_sentence_label` are not `None`:
473
+ Outputs the total_loss which is the sum of the masked language modeling loss and the next
474
+ sentence classification loss.
475
+ if `labels` or `next_sentence_label` is `None`:
476
+ Outputs a tuple comprising
477
+ - the masked language modeling logits of shape [batch_size, sequence_length, vocab_size], and
478
+ - the next sentence classification logits of shape [batch_size, 2].
479
+
480
+ """
481
+ masked_tokens_mask = labels > 0 if (self.last_layer_subset and labels is not None) else None
482
+ outputs = self.bert(
483
+ input_ids,
484
+ position_ids=position_ids,
485
+ token_type_ids=token_type_ids,
486
+ attention_mask=attention_mask.bool() if attention_mask is not None else None,
487
+ masked_tokens_mask=masked_tokens_mask,
488
+ )
489
+ sequence_output, pooled_output = outputs.last_hidden_state, outputs.pooler_output
490
+ if self.dense_seq_output and labels is not None:
491
+ masked_token_idx = torch.nonzero(labels.flatten() > 0, as_tuple=False).flatten()
492
+ if not self.last_layer_subset:
493
+ sequence_output = index_first_axis(
494
+ rearrange(sequence_output, "b s d -> (b s) d"), masked_token_idx
495
+ )
496
+ prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output)
497
+
498
+ total_loss = None
499
+ if labels is not None and next_sentence_label is not None:
500
+ if (
501
+ self.dense_seq_output and labels is not None
502
+ ): # prediction_scores are already flattened
503
+ masked_lm_loss = self.mlm_loss(
504
+ prediction_scores, labels.flatten()[masked_token_idx]
505
+ )
506
+ else:
507
+ masked_lm_loss = self.mlm_loss(
508
+ rearrange(prediction_scores, "... v -> (...) v"),
509
+ rearrange(labels, "... -> (...)"),
510
+ )
511
+ next_sentence_loss = self.nsp_loss(
512
+ rearrange(seq_relationship_score, "... t -> (...) t"),
513
+ rearrange(next_sentence_label, "... -> (...)"),
514
+ )
515
+ total_loss = masked_lm_loss.float() + next_sentence_loss.float()
516
+
517
+ return BertForPreTrainingOutput(
518
+ loss=total_loss,
519
+ prediction_logits=prediction_scores,
520
+ seq_relationship_logits=seq_relationship_score,
521
+ )
522
+
523
+
524
+ def remap_state_dict(state_dict, config: PretrainedConfig):
525
+ """
526
+ Map the state_dict of a Huggingface BERT model to be flash_attn compatible.
527
+ """
528
+
529
+ # LayerNorm
530
+ def key_mapping_ln_gamma_beta(key):
531
+ key = re.sub(r"LayerNorm.gamma$", "LayerNorm.weight", key)
532
+ key = re.sub(r"LayerNorm.beta$", "LayerNorm.bias", key)
533
+ return key
534
+
535
+ state_dict = OrderedDict((key_mapping_ln_gamma_beta(k), v) for k, v in state_dict.items())
536
+
537
+ # Layers
538
+ def key_mapping_layers(key):
539
+ return re.sub(r"^bert.encoder.layer.", "bert.encoder.layers.", key)
540
+
541
+ state_dict = OrderedDict((key_mapping_layers(k), v) for k, v in state_dict.items())
542
+
543
+ # LayerNorm
544
+ def key_mapping_ln(key):
545
+ key = re.sub(r"^bert.embeddings.LayerNorm.", "bert.emb_ln.", key)
546
+ key = re.sub(
547
+ r"^bert.encoder.layers.(\d+).attention.output.LayerNorm.(weight|bias)",
548
+ r"bert.encoder.layers.\1.norm1.\2",
549
+ key,
550
+ )
551
+ key = re.sub(
552
+ r"^bert.encoder.layers.(\d+).output.LayerNorm.(weight|bias)",
553
+ r"bert.encoder.layers.\1.norm2.\2",
554
+ key,
555
+ )
556
+ key = re.sub(
557
+ r"^cls.predictions.transform.LayerNorm.(weight|bias)",
558
+ r"cls.predictions.transform.layer_norm.\1",
559
+ key,
560
+ )
561
+ return key
562
+
563
+ state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items())
564
+
565
+ # MLP
566
+ def key_mapping_mlp(key):
567
+ key = re.sub(
568
+ r"^bert.encoder.layers.(\d+).intermediate.dense.(weight|bias)",
569
+ r"bert.encoder.layers.\1.mlp.fc1.\2",
570
+ key,
571
+ )
572
+ key = re.sub(
573
+ r"^bert.encoder.layers.(\d+).output.dense.(weight|bias)",
574
+ r"bert.encoder.layers.\1.mlp.fc2.\2",
575
+ key,
576
+ )
577
+ return key
578
+
579
+ state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items())
580
+
581
+ # Attention
582
+ last_layer_subset = getattr(config, "last_layer_subset", False)
583
+ for d in range(config.num_hidden_layers):
584
+ Wq = state_dict.pop(f"bert.encoder.layers.{d}.attention.self.query.weight")
585
+ Wk = state_dict.pop(f"bert.encoder.layers.{d}.attention.self.key.weight")
586
+ Wv = state_dict.pop(f"bert.encoder.layers.{d}.attention.self.value.weight")
587
+ bq = state_dict.pop(f"bert.encoder.layers.{d}.attention.self.query.bias")
588
+ bk = state_dict.pop(f"bert.encoder.layers.{d}.attention.self.key.bias")
589
+ bv = state_dict.pop(f"bert.encoder.layers.{d}.attention.self.value.bias")
590
+ if not (last_layer_subset and d == config.num_hidden_layers - 1):
591
+ state_dict[f"bert.encoder.layers.{d}.mixer.Wqkv.weight"] = torch.cat(
592
+ [Wq, Wk, Wv], dim=0
593
+ )
594
+ state_dict[f"bert.encoder.layers.{d}.mixer.Wqkv.bias"] = torch.cat([bq, bk, bv], dim=0)
595
+ else:
596
+ state_dict[f"bert.encoder.layers.{d}.mixer.Wq.weight"] = Wq
597
+ state_dict[f"bert.encoder.layers.{d}.mixer.Wkv.weight"] = torch.cat([Wk, Wv], dim=0)
598
+ state_dict[f"bert.encoder.layers.{d}.mixer.Wq.bias"] = bq
599
+ state_dict[f"bert.encoder.layers.{d}.mixer.Wkv.bias"] = torch.cat([bk, bv], dim=0)
600
+
601
+ def key_mapping_attn(key):
602
+ return re.sub(
603
+ r"^bert.encoder.layers.(\d+).attention.output.dense.(weight|bias)",
604
+ r"bert.encoder.layers.\1.mixer.out_proj.\2",
605
+ key,
606
+ )
607
+
608
+ state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())
609
+
610
+ def key_mapping_decoder_bias(key):
611
+ return re.sub(r"^cls.predictions.bias", "cls.predictions.decoder.bias", key)
612
+
613
+ state_dict = OrderedDict((key_mapping_decoder_bias(k), v) for k, v in state_dict.items())
614
+
615
+ # Word embedding
616
+ pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
617
+ if pad_vocab_size_multiple > 1:
618
+ word_embeddings = state_dict["bert.embeddings.word_embeddings.weight"]
619
+ state_dict["bert.embeddings.word_embeddings.weight"] = F.pad(
620
+ word_embeddings, (0, 0, 0, config.vocab_size - word_embeddings.shape[0])
621
+ )
622
+ decoder_weight = state_dict["cls.predictions.decoder.weight"]
623
+ state_dict["cls.predictions.decoder.weight"] = F.pad(
624
+ decoder_weight, (0, 0, 0, config.vocab_size - decoder_weight.shape[0])
625
+ )
626
+ # If the vocab was padded, we want to set the decoder bias for those padded indices to be
627
+ # strongly negative (i.e. the decoder shouldn't predict those indices).
628
+ # TD [2022-05-09]: I don't think it affects the MLPerf training.
629
+ decoder_bias = state_dict["cls.predictions.decoder.bias"]
630
+ state_dict["cls.predictions.decoder.bias"] = F.pad(
631
+ decoder_bias, (0, config.vocab_size - decoder_bias.shape[0]), value=-100.0
632
+ )
633
+
634
+ return state_dict
635
+
636
+
637
def inv_remap_state_dict(state_dict, config: PretrainedConfig):
    """
    Map the state_dict of a flash_attn model to be Huggingface BERT compatible.

    This function is meant to be the inverse of remap_state_dict.
    """
    # Word embedding: undo the vocab-size padding applied by remap_state_dict by
    # truncating the embedding / decoder rows back to the original vocabulary size.
    # NOTE(review): assumes `config.orig_vocab_size` was recorded when the vocab was
    # padded -- confirm against the forward remap's caller.
    pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
    if pad_vocab_size_multiple > 1:
        word_embeddings = state_dict["bert.embeddings.word_embeddings.weight"]
        decoder_weight = state_dict["cls.predictions.decoder.weight"]
        decoder_bias = state_dict["cls.predictions.decoder.bias"]
        # unpad embeddings
        state_dict["bert.embeddings.word_embeddings.weight"] = word_embeddings[
            : config.orig_vocab_size, :
        ]
        state_dict["cls.predictions.decoder.weight"] = decoder_weight[: config.orig_vocab_size, :]
        state_dict["cls.predictions.decoder.bias"] = decoder_bias[: config.orig_vocab_size]

    # Attention: split the fused projections back into separate query/key/value.
    for d in range(config.num_hidden_layers):
        last_layer_subset = getattr(config, "last_layer_subset", False)
        if not last_layer_subset or d != (config.num_hidden_layers - 1):
            # The fused Wqkv stacks rows as [q | k | v], each one third of dim 0.
            Wqkv_weights = state_dict.pop(f"bert.encoder.layers.{d}.mixer.Wqkv.weight")
            Wqkv_biases = state_dict.pop(f"bert.encoder.layers.{d}.mixer.Wqkv.bias")
            state_dict[f"bert.encoder.layers.{d}.attention.self.query.weight"] = Wqkv_weights[
                : Wqkv_weights.shape[0] // 3, :
            ]
            state_dict[f"bert.encoder.layers.{d}.attention.self.key.weight"] = Wqkv_weights[
                Wqkv_weights.shape[0] // 3 : 2 * Wqkv_weights.shape[0] // 3, :
            ]
            state_dict[f"bert.encoder.layers.{d}.attention.self.value.weight"] = Wqkv_weights[
                2 * Wqkv_weights.shape[0] // 3 :, :
            ]
            state_dict[f"bert.encoder.layers.{d}.attention.self.query.bias"] = Wqkv_biases[
                : Wqkv_biases.shape[0] // 3
            ]
            state_dict[f"bert.encoder.layers.{d}.attention.self.key.bias"] = Wqkv_biases[
                Wqkv_biases.shape[0] // 3 : 2 * Wqkv_biases.shape[0] // 3
            ]
            state_dict[f"bert.encoder.layers.{d}.attention.self.value.bias"] = Wqkv_biases[
                2 * Wqkv_biases.shape[0] // 3 :
            ]
        else:
            # With last_layer_subset, the final layer keeps Q separate and only
            # fuses [k | v] (each one half of dim 0).
            Wq_weight = state_dict.pop(f"bert.encoder.layers.{d}.mixer.Wq.weight")
            Wkv_weights = state_dict.pop(f"bert.encoder.layers.{d}.mixer.Wkv.weight")
            Wq_bias = state_dict.pop(f"bert.encoder.layers.{d}.mixer.Wq.bias")
            Wkv_biases = state_dict.pop(f"bert.encoder.layers.{d}.mixer.Wkv.bias")
            state_dict[f"bert.encoder.layers.{d}.attention.self.query.weight"] = Wq_weight
            state_dict[f"bert.encoder.layers.{d}.attention.self.key.weight"] = Wkv_weights[
                : Wkv_weights.shape[0] // 2, :
            ]
            state_dict[f"bert.encoder.layers.{d}.attention.self.value.weight"] = Wkv_weights[
                Wkv_weights.shape[0] // 2 :, :
            ]
            state_dict[f"bert.encoder.layers.{d}.attention.self.query.bias"] = Wq_bias
            state_dict[f"bert.encoder.layers.{d}.attention.self.key.bias"] = Wkv_biases[
                : Wkv_biases.shape[0] // 2
            ]
            state_dict[f"bert.encoder.layers.{d}.attention.self.value.bias"] = Wkv_biases[
                Wkv_biases.shape[0] // 2 :
            ]

    # The key renames below invert the renames done in remap_state_dict.
    def inv_key_mapping_ln(key):
        key = re.sub(r"bert.emb_ln.", "bert.embeddings.LayerNorm.", key)
        key = re.sub(
            r"bert.encoder.layers.(\d+).norm1.(weight|bias)",
            r"bert.encoder.layers.\1.attention.output.LayerNorm.\2",
            key,
        )
        key = re.sub(
            r"bert.encoder.layers.(\d+).norm2.(weight|bias)",
            r"bert.encoder.layers.\1.output.LayerNorm.\2",
            key,
        )
        key = re.sub(
            r"cls.predictions.transform.layer_norm.(weight|bias)",
            r"cls.predictions.transform.LayerNorm.\1",
            key,
        )
        return key

    def inv_key_mapping_ln_gamma_beta(key):
        # Rename LayerNorm weight/bias to the gamma/beta key names.
        key = re.sub(r"LayerNorm.weight$", "LayerNorm.gamma", key)
        key = re.sub(r"LayerNorm.bias$", "LayerNorm.beta", key)
        return key

    def inv_key_mapping_layers(key):
        return re.sub(r"bert.encoder.layers.", "bert.encoder.layer.", key)

    def inv_key_mapping_mlp(key):
        key = re.sub(
            r"bert.encoder.layer.(\d+).mlp.fc1.(weight|bias)",
            r"bert.encoder.layer.\1.intermediate.dense.\2",
            key,
        )
        key = re.sub(
            r"bert.encoder.layer.(\d+).mlp.fc2.(weight|bias)",
            r"bert.encoder.layer.\1.output.dense.\2",
            key,
        )
        return key

    def inv_key_mapping_attn(key):
        return re.sub(
            r"bert.encoder.layer.(\d+).mixer.out_proj.(weight|bias)",
            r"bert.encoder.layer.\1.attention.output.dense.\2",
            key,
        )

    def inv_key_mapping_decoder_bias(key):
        return re.sub(r"cls.predictions.decoder.bias", "cls.predictions.bias", key)

    # The mappings are applied in sequence and the order matters: the
    # "layers" -> "layer" rename must run before the mlp/attn renames, whose
    # patterns match the singular "layer." form.
    state_dict = OrderedDict((inv_key_mapping_ln(key), value) for key, value in state_dict.items())
    state_dict = OrderedDict(
        (inv_key_mapping_ln_gamma_beta(key), value) for key, value in state_dict.items()
    )
    state_dict = OrderedDict(
        (inv_key_mapping_layers(key), value) for key, value in state_dict.items()
    )
    state_dict = OrderedDict((inv_key_mapping_mlp(key), value) for key, value in state_dict.items())
    state_dict = OrderedDict(
        (inv_key_mapping_attn(key), value) for key, value in state_dict.items()
    )
    state_dict = OrderedDict(
        (inv_key_mapping_decoder_bias(key), value) for key, value in state_dict.items()
    )

    return state_dict
.venv/lib/python3.11/site-packages/xformers/_flash_attn/models/bigcode.py ADDED
@@ -0,0 +1,233 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import re
3
+ from collections import OrderedDict
4
+
5
+ import torch
6
+ import torch.nn.functional as F
7
+ from transformers import GPT2Config, GPTBigCodeConfig, PretrainedConfig
8
+
9
+
10
def remap_state_dict_hf_bigcode(state_dict, config: PretrainedConfig):
    """
    Map the state_dict of a Huggingface BigCode model to be flash_attn compatible.
    """

    # Word embedding and position embedding
    def key_mapping_pos_emb(key):
        return re.sub(r"^transformer.wpe.", "transformer.embeddings.position_embeddings.", key)

    state_dict = OrderedDict((key_mapping_pos_emb(k), v) for k, v in state_dict.items())
    word_embeddings = state_dict.pop("transformer.wte.weight")
    # It's possible that vocab_size is padded to be a multiple of 8, for example.
    pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
    vocab_size = math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple
    state_dict["transformer.embeddings.word_embeddings.weight"] = F.pad(
        word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0])
    )
    # The LM head is tied to the (padded) input embeddings.
    state_dict["lm_head.weight"] = state_dict["transformer.embeddings.word_embeddings.weight"]

    # LayerNorm
    def key_mapping_ln(key):
        # The ln_f sub rewrites the key to itself (kept for symmetry with the
        # other model remap functions).
        key = re.sub(r"^transformer.ln_f.(weight|bias)", r"transformer.ln_f.\1", key)
        key = re.sub(
            r"^transformer.h.(\d+).ln_(1|2).(weight|bias)",
            r"transformer.layers.\1.norm\2.\3",
            key,
        )
        return key

    state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items())

    # MLP: c_fc -> fc1, c_proj -> fc2 (pure key renames, no tensor changes).
    def key_mapping_mlp(key):
        key = re.sub(
            r"^transformer.h.(\d+).mlp.c_fc.weight",
            r"transformer.layers.\1.mlp.fc1.weight",
            key,
        )
        key = re.sub(
            r"^transformer.h.(\d+).mlp.c_proj.weight",
            r"transformer.layers.\1.mlp.fc2.weight",
            key,
        )
        key = re.sub(
            r"^transformer.h.(\d+).mlp.c_fc.bias",
            r"transformer.layers.\1.mlp.fc1.bias",
            key,
        )
        key = re.sub(
            r"^transformer.h.(\d+).mlp.c_proj.bias",
            r"transformer.layers.\1.mlp.fc2.bias",
            key,
        )
        return key

    state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items())

    # TODO: add support for multi-head attention
    assert config.multi_query, "Only multi-query attention is supported"

    # Attention
    for d in range(config.num_hidden_layers):
        embed_dim = config.n_embd
        head_dim = embed_dim // config.n_head

        c_attn_weight = state_dict.pop(f"transformer.h.{d}.attn.c_attn.weight")
        # With multi-query attention the fused projection stacks one full-size Q
        # block plus a single K head and a single V head along dim 0, i.e. shape
        # (embed_dim + head_dim + head_dim, embed_dim).
        # see https://github.com/huggingface/transformers/blob/95b374952dc27d8511541d6f5a4e22c9ec11fb24/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py#L112
        # see also https://github.com/ggerganov/ggml/blob/dd1d575956e54c5bdc07632f25506b3b1884dbd2/examples/starcoder/convert-hf-to-ggml.py#L183
        # ((n_head + 2) * head_dim, embed_dim) -> (3 * n_heads * head_dim, hidden_dim)
        q, k, v = torch.split(c_attn_weight, [embed_dim, head_dim, head_dim], dim=0)
        # duplicate k, v along the first axis (head_dim, hidden_dim) -> (n_heads * head_dim, hidden_dim)
        k = torch.tile(k, (config.n_head, 1))
        v = torch.tile(v, (config.n_head, 1))
        state_dict[f"transformer.layers.{d}.mixer.Wqkv.weight"] = torch.cat((q, k, v), dim=0)

        # same deal with the bias
        c_attn_bias = state_dict.pop(f"transformer.h.{d}.attn.c_attn.bias")
        # (embed_dim + 2 * head_dim,) -> (3 * n_heads * head_dim,)
        q, k, v = torch.split(c_attn_bias, [embed_dim, head_dim, head_dim], dim=0)
        # duplicate k, v along the first axis: (head_dim,) -> (n_heads * head_dim,)
        k = torch.tile(k, (config.n_head,))
        v = torch.tile(v, (config.n_head,))
        state_dict[f"transformer.layers.{d}.mixer.Wqkv.bias"] = torch.cat((q, k, v), dim=0)

    def key_mapping_attn(key):
        key = re.sub(
            r"^transformer.h.(\d+).attn.c_proj.weight",
            r"transformer.layers.\1.mixer.out_proj.weight",
            key,
        )
        key = re.sub(
            r"^transformer.h.(\d+).attn.c_proj.bias",
            r"transformer.layers.\1.mixer.out_proj.bias",
            key,
        )
        return key

    state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())

    return state_dict
110
+
111
+
112
def inv_remap_state_dict_hf_bigcode(state_dict, config: PretrainedConfig):
    """
    Map the state_dict of a flash_attn model to be Huggingface BigCode compatible.

    This function is meant to be the inverse of remap_state_dict_hf_bigcode.
    """

    # Word embedding and position embeddings
    def inv_key_mapping_pos_emb(key):
        return re.sub(r"^transformer.embeddings.position_embeddings.", "transformer.wpe.", key)

    state_dict = OrderedDict((inv_key_mapping_pos_emb(k), v) for k, v in state_dict.items())
    word_embeddings = state_dict.pop("transformer.embeddings.word_embeddings.weight")

    # The forward remap pads the vocabulary along dim 0 (extra token rows), via
    # F.pad(word_embeddings, (0, 0, 0, vocab_size - rows)). The inverse therefore
    # truncates rows. (Slicing dim 1 instead would cut the embedding dimension
    # and leave the padded rows in place.)
    word_embeddings = word_embeddings[: config.vocab_size, :]
    state_dict["transformer.wte.weight"] = word_embeddings
    # Input and output embeddings are tied.
    state_dict["lm_head.weight"] = word_embeddings

    # LayerNorm
    def inv_key_mapping_ln(key):
        # The ln_f sub rewrites the key to itself (kept for symmetry with the
        # forward remap).
        key = re.sub(r"^transformer.ln_f.(weight|bias)", r"transformer.ln_f.\1", key)
        key = re.sub(
            r"^transformer.layers.(\d+).norm(1|2).(weight|bias)",
            r"transformer.h.\1.ln_\2.\3",
            key,
        )
        return key

    state_dict = OrderedDict((inv_key_mapping_ln(k), v) for k, v in state_dict.items())

    # MLPs: fc1 -> c_fc, fc2 -> c_proj (pure key renames).
    def inv_key_mapping_mlp(key):
        key = re.sub(
            r"^transformer.layers.(\d+).mlp.fc1.weight",
            r"transformer.h.\1.mlp.c_fc.weight",
            key,
        )
        key = re.sub(
            r"^transformer.layers.(\d+).mlp.fc2.weight",
            r"transformer.h.\1.mlp.c_proj.weight",
            key,
        )
        key = re.sub(
            r"^transformer.layers.(\d+).mlp.fc1.bias",
            r"transformer.h.\1.mlp.c_fc.bias",
            key,
        )
        key = re.sub(
            r"^transformer.layers.(\d+).mlp.fc2.bias",
            r"transformer.h.\1.mlp.c_proj.bias",
            key,
        )
        return key

    state_dict = OrderedDict((inv_key_mapping_mlp(k), v) for k, v in state_dict.items())

    # Attention
    for d in range(config.num_hidden_layers):
        embed_dim = config.n_embd
        head_dim = embed_dim // config.n_head

        # The flash_attn layout replicated K and V across all heads; keep only the
        # first head's slice to recover the multi-query [q | k | v] layout.
        Wqkv_weight = state_dict.pop(f"transformer.layers.{d}.mixer.Wqkv.weight")
        q, k, v = torch.split(
            Wqkv_weight, [embed_dim, head_dim * config.n_head, head_dim * config.n_head], dim=0
        )
        c_attn_weight = torch.cat((q, k[:head_dim], v[:head_dim]), dim=0)
        state_dict[f"transformer.h.{d}.attn.c_attn.weight"] = c_attn_weight

        # Same deal with the bias
        Wqkv_bias = state_dict.pop(f"transformer.layers.{d}.mixer.Wqkv.bias")
        q, k, v = torch.split(
            Wqkv_bias, [embed_dim, head_dim * config.n_head, head_dim * config.n_head], dim=0
        )
        c_attn_bias = torch.cat((q, k[:head_dim], v[:head_dim]), dim=0)
        state_dict[f"transformer.h.{d}.attn.c_attn.bias"] = c_attn_bias

    def inv_key_mapping_attn(key):
        key = re.sub(
            r"^transformer.layers.(\d+).mixer.out_proj.weight",
            r"transformer.h.\1.attn.c_proj.weight",
            key,
        )
        key = re.sub(
            r"^transformer.layers.(\d+).mixer.out_proj.bias",
            r"transformer.h.\1.attn.c_proj.bias",
            key,
        )
        return key

    state_dict = OrderedDict((inv_key_mapping_attn(k), v) for k, v in state_dict.items())

    return state_dict
204
+
205
+
206
def bigcode_config_to_gpt2_config(bigcode_config: GPTBigCodeConfig) -> GPT2Config:
    """Build a ``GPT2Config`` mirroring the given ``GPTBigCodeConfig``.

    Every field consumed here carries the same name on both config classes, so
    the translation is a straight field-by-field copy.
    """
    copied_fields = (
        "activation_function",
        "attn_pdrop",
        "bos_token_id",
        "embd_pdrop",
        "eos_token_id",
        "initializer_range",
        "layer_norm_epsilon",
        "max_batch_size",
        "max_sequence_length",
        "model_type",
        "multi_query",
        "n_embd",
        "n_head",
        "n_inner",
        "n_layer",
        "n_positions",
        "resid_pdrop",
        "scale_attn_weights",
        "summary_activation",
        "summary_first_dropout",
        "summary_proj_to_labels",
        "summary_type",
        "summary_use_proj",
        "use_cache",
        "vocab_size",
    )
    return GPT2Config(**{name: getattr(bigcode_config, name) for name in copied_fields})
.venv/lib/python3.11/site-packages/xformers/_flash_attn/models/btlm.py ADDED
@@ -0,0 +1,102 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2023, Tri Dao.
2
+
3
+ import math
4
+ import json
5
+ import re
6
+ from pathlib import Path
7
+
8
+ from collections import OrderedDict
9
+
10
+ import torch
11
+ import torch.nn.functional as F
12
+
13
+ from einops import rearrange
14
+ from transformers import GPT2Config, AutoConfig, PretrainedConfig
15
+
16
+
17
def remap_state_dict_hf_btlm(state_dict, config):
    """Convert a Huggingface BTLM checkpoint state_dict into the flash_attn layout.

    Returns a new OrderedDict. Conv1D-style (in, out) projection weights are
    transposed into nn.Linear (out, in) layout, the two gated-MLP input
    projections are fused into a single fc1, and the stored Alibi slopes buffer
    is dropped.
    """
    # Absolute position embeddings only exist for non-Alibi checkpoints.
    if "transformer.wpe.weight" in state_dict:
        state_dict = OrderedDict(
            (
                re.sub(r"^transformer.wpe.", "transformer.embeddings.position_embeddings.", name),
                tensor,
            )
            for name, tensor in state_dict.items()
        )

    # Word embeddings, padded up to a multiple of pad_vocab_size_multiple.
    embeddings = state_dict.pop("transformer.wte.weight")
    multiple = getattr(config, "pad_vocab_size_multiple", 1)
    padded_vocab = math.ceil(config.vocab_size / multiple) * multiple
    padded_embeddings = F.pad(embeddings, (0, 0, 0, padded_vocab - embeddings.shape[0]))
    state_dict["transformer.embeddings.word_embeddings.weight"] = padded_embeddings
    # LM head is tied to the (padded) input embeddings.
    state_dict["lm_head.weight"] = padded_embeddings

    # LayerNorms: ln_1/ln_2 -> norm1/norm2 (ln_f keeps its name).
    state_dict = OrderedDict(
        (
            re.sub(
                r"^transformer.h.(\d+).ln_(1|2).(weight|bias)",
                r"transformer.layers.\1.norm\2.\3",
                name,
            ),
            tensor,
        )
        for name, tensor in state_dict.items()
    )

    # Gated MLP: fuse c_fc (W1) and c_fc2 (W3) into one fc1, transposing from
    # Conv1D (in, out) layout to Linear (out, in) layout.
    for layer in range(config.num_hidden_layers):
        src = f"transformer.h.{layer}.mlp"
        dst = f"transformer.layers.{layer}.mlp"
        w1 = state_dict.pop(f"{src}.c_fc.weight")
        w3 = state_dict.pop(f"{src}.c_fc2.weight")
        state_dict[f"{dst}.fc1.weight"] = torch.cat([w1.t(), w3.t()], dim=0)
        b1 = state_dict.pop(f"{src}.c_fc.bias")
        b3 = state_dict.pop(f"{src}.c_fc2.bias")
        state_dict[f"{dst}.fc1.bias"] = torch.cat([b1, b3], dim=0)
        state_dict[f"{dst}.fc2.weight"] = state_dict.pop(f"{src}.c_proj.weight").t()

    # The fc2 bias needs no tensor change, only a key rename.
    state_dict = OrderedDict(
        (
            re.sub(
                r"^transformer.h.(\d+).mlp.c_proj.bias", r"transformer.layers.\1.mlp.fc2.bias", name
            ),
            tensor,
        )
        for name, tensor in state_dict.items()
    )

    # Attention: transpose the fused QKV and the output projection.
    for layer in range(config.num_hidden_layers):
        wqkv = state_dict.pop(f"transformer.h.{layer}.attn.c_attn.weight")
        state_dict[f"transformer.layers.{layer}.mixer.Wqkv.weight"] = wqkv.t()
        wout = state_dict.pop(f"transformer.h.{layer}.attn.c_proj.weight")
        state_dict[f"transformer.layers.{layer}.mixer.out_proj.weight"] = wout.t()
    # We don't store the Alibi slopes; they are recomputed at runtime.
    state_dict.pop("transformer.relative_pe.slopes")

    def rename_attn_biases(name):
        name = re.sub(
            r"^transformer.h.(\d+).attn.c_attn.bias", r"transformer.layers.\1.mixer.Wqkv.bias", name
        )
        return re.sub(
            r"^transformer.h.(\d+).attn.c_proj.bias",
            r"transformer.layers.\1.mixer.out_proj.bias",
            name,
        )

    state_dict = OrderedDict((rename_attn_biases(k), v) for k, v in state_dict.items())

    return state_dict
76
+
77
+
78
def btlm_config_to_gpt2_config(btlm_config: PretrainedConfig) -> GPT2Config:
    """Translate a BTLM Huggingface config into the GPT2Config used by flash_attn."""
    # BTLM uses Alibi rather than absolute position embeddings; the Alibi code
    # path here is only implemented on top of flash-attn.
    alibi = btlm_config.position_embedding_type == "alibi"
    kwargs = dict(
        vocab_size=btlm_config.vocab_size,
        n_positions=0 if alibi else btlm_config.n_positions,
        n_embd=btlm_config.hidden_size,
        n_layer=btlm_config.num_hidden_layers,
        n_head=btlm_config.num_attention_heads,
        n_inner=btlm_config.n_inner,
        activation_function=btlm_config.activation_function,
        resid_pdrop=btlm_config.resid_pdrop,
        embd_pdrop=btlm_config.embd_pdrop,
        attn_pdrop=btlm_config.attn_pdrop,
        layer_norm_epsilon=btlm_config.layer_norm_epsilon,
        initializer_range=btlm_config.initializer_range,
        bos_token_id=btlm_config.bos_token_id,
        eos_token_id=btlm_config.eos_token_id,
    )
    # The arguments below are flash_attn extensions, not part of the stock GPT2Config.
    kwargs.update(
        use_alibi=alibi,
        use_flash_attn=alibi,  # Alibi code path requires flash_attn
        mup_width_scale=btlm_config.mup_width_scale,
        mup_embeddings_multiplier=btlm_config.mup_embeddings_scale,
        mup_output_multiplier=btlm_config.mup_output_alpha,
        mup_scale_qk_dot_by_d=btlm_config.mup_scale_qk_dot_by_d,
        mlp_multiple_of=1,
    )
    return GPT2Config(**kwargs)
.venv/lib/python3.11/site-packages/xformers/_flash_attn/models/falcon.py ADDED
@@ -0,0 +1,143 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2023, Tri Dao.
2
+
3
+ import math
4
+ import re
5
+ from collections import OrderedDict
6
+
7
+ import torch
8
+ import torch.nn.functional as F
9
+ from einops import rearrange
10
+ from transformers import FalconConfig, GPT2Config
11
+
12
+
13
def remap_state_dict_hf_falcon(state_dict, config):
    """Map a Huggingface Falcon state_dict to the flash_attn GPT layout.

    ``config`` is the GPT2Config produced by falcon_config_to_gpt2_config.
    """

    def key_mapping_layers(key):
        return re.sub(r"^transformer.h.", "transformer.layers.", key)

    state_dict = OrderedDict((key_mapping_layers(k), v) for k, v in state_dict.items())

    # Word embedding
    def key_mapping_emb(key):
        return re.sub(
            r"^transformer.word_embeddings.", "transformer.embeddings.word_embeddings.", key
        )

    state_dict = OrderedDict((key_mapping_emb(k), v) for k, v in state_dict.items())
    word_embeddings = state_dict.pop("transformer.embeddings.word_embeddings.weight")
    # It's possible that vocab_size is padded to be a multiple of 8, for example.
    pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
    vocab_size = math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple
    state_dict["transformer.embeddings.word_embeddings.weight"] = F.pad(
        word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0])
    )
    if getattr(config, "tie_word_embeddings"):
        # Tied embeddings: the LM head shares the (padded) input embedding matrix.
        state_dict["lm_head.weight"] = state_dict["transformer.embeddings.word_embeddings.weight"]
    else:
        # Untied: pad the separate output embedding (and its bias) the same way.
        output_embeddings = state_dict.pop("lm_head.weight")
        # It's possible that vocab_size is padded to be a multiple of 8, for example.
        state_dict["lm_head.weight"] = F.pad(
            output_embeddings, (0, 0, 0, vocab_size - output_embeddings.shape[0])
        )
        output_embeddings_bias = state_dict.pop("lm_head.bias")
        state_dict["lm_head.bias"] = F.pad(
            output_embeddings_bias, (0, vocab_size - output_embeddings_bias.shape[0])
        )

    # LayerNorm. Falcon-7B-style checkpoints use input_layernorm /
    # post_attention_layernorm; 40B-style checkpoints use ln_attn / ln_mlp.
    # Both spellings are mapped to norm1 / norm2.
    def key_mapping_ln(key):
        key = re.sub(
            r"^transformer.layers.(\d+).input_layernorm.", r"transformer.layers.\1.norm1.", key
        )
        key = re.sub(
            r"^transformer.layers.(\d+).post_attention_layernorm.",
            r"transformer.layers.\1.norm2.",
            key,
        )
        key = re.sub(r"^transformer.layers.(\d+).ln_attn.", r"transformer.layers.\1.norm1.", key)
        key = re.sub(r"^transformer.layers.(\d+).ln_mlp.", r"transformer.layers.\1.norm2.", key)
        return key

    state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items())

    # MLP
    def key_mapping_mlp(key):
        key = re.sub(
            r"^transformer.layers.(\d+).mlp.dense_h_to_4h.", r"transformer.layers.\1.mlp.fc1.", key
        )
        key = re.sub(
            r"^transformer.layers.(\d+).mlp.dense_4h_to_h.", r"transformer.layers.\1.mlp.fc2.", key
        )
        return key

    state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items())

    def key_mapping_attn(key):
        key = re.sub(
            r"^transformer.layers.(\d+).self_attention.query_key_value.",
            r"transformer.layers.\1.mixer.Wqkv.",
            key,
        )
        key = re.sub(
            r"^transformer.layers.(\d+).self_attention.dense.",
            r"transformer.layers.\1.mixer.out_proj.",
            key,
        )
        return key

    state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())
    n_head = config.n_head
    n_head_kv = getattr(config, "n_head_kv", 1)
    headdim = config.hidden_size // n_head
    for l in range(config.n_layer):
        # The weights are stored in a different layout compared to our implementation:
        # HF Falcon groups rows per KV group as [q ... q, k, v] (n_head // n_head_kv
        # query heads followed by one key and one value head per group). Regroup
        # them into the flash_attn [all q | all k | all v] row order.
        Wqkv = rearrange(
            state_dict.pop(f"transformer.layers.{l}.mixer.Wqkv.weight"),
            "(group ratio headdim) ... -> group ratio headdim ...",
            ratio=n_head // n_head_kv + 2,
            headdim=headdim,
        )
        Wq = rearrange(Wqkv[:, :-2], "group ratio headdim ... -> (group ratio headdim) ...")
        Wk = rearrange(Wqkv[:, [-2]], "group ratio headdim ... -> (group ratio headdim) ...")
        Wv = rearrange(Wqkv[:, [-1]], "group ratio headdim ... -> (group ratio headdim) ...")
        state_dict[f"transformer.layers.{l}.mixer.Wqkv.weight"] = torch.cat([Wq, Wk, Wv], dim=0)

    return state_dict
104
+
105
+
106
def falcon_config_to_gpt2_config(falcon_config: FalconConfig) -> GPT2Config:
    """Translate a Huggingface ``FalconConfig`` into flash_attn's extended ``GPT2Config``."""
    # The 40b config uses "n_head_kv" instead of "num_kv_heads".
    kv_heads = getattr(
        falcon_config,
        "n_head_kv",
        1 if getattr(falcon_config, "multi_query", False) else falcon_config.n_head,
    )
    # HACK: the 40b config has 2 LN per layer instead of 1, but that's not
    # reflected in the config. So we have to infer it from the number of heads
    # in the key/value block.
    tied_norm = kv_heads == 1
    kwargs = dict(
        vocab_size=falcon_config.vocab_size,
        n_positions=0,  # No absolute position embedding
        n_embd=falcon_config.hidden_size,
        n_layer=falcon_config.n_layer,
        n_head=falcon_config.n_head,
        n_inner=falcon_config.hidden_size * 4,
        activation_function="gelu",
        resid_pdrop=falcon_config.hidden_dropout,
        embd_pdrop=0.0,  # There doesn't seem to be any embedding dropout
        attn_pdrop=falcon_config.attention_dropout,
        layer_norm_epsilon=falcon_config.layer_norm_epsilon,
        initializer_range=falcon_config.initializer_range,
        bos_token_id=falcon_config.bos_token_id,
        eos_token_id=falcon_config.eos_token_id,
    )
    # The arguments below are flash_attn extensions, not part of the stock GPT2Config.
    kwargs.update(
        parallel_block=falcon_config.parallel_attn,
        n_head_kv=kv_heads,
        parallel_block_tied_norm=tied_norm,
        rotary_emb_fraction=1.0,
        rotary_emb_interleaved=False,
        tie_word_embeddings=True,
        qkv_proj_bias=falcon_config.bias,
        out_proj_bias=falcon_config.bias,
        mlp_fc1_bias=falcon_config.bias,
        mlp_fc2_bias=falcon_config.bias,
        lm_head_bias=False,
    )
    return GPT2Config(**kwargs)
.venv/lib/python3.11/site-packages/xformers/_flash_attn/models/gpt.py ADDED
@@ -0,0 +1,1080 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2024, Tri Dao.
2
+
3
+ import logging
4
+ import math
5
+ import re
6
+ from collections import OrderedDict, namedtuple
7
+ from collections.abc import Sequence
8
+ from functools import partial
9
+ from typing import Dict, List
10
+
11
+ import torch
12
+ import torch.nn as nn
13
+ import torch.nn.functional as F
14
+ from einops import rearrange
15
+ from transformers import GPT2Config
16
+
17
+ from flash_attn.models.bigcode import remap_state_dict_hf_bigcode
18
+ from flash_attn.models.falcon import remap_state_dict_hf_falcon
19
+ from flash_attn.models.gpt_neox import remap_state_dict_hf_gpt_neox
20
+ from flash_attn.models.gptj import remap_state_dict_hf_gptj
21
+ from flash_attn.models.llama import remap_state_dict_hf_llama
22
+ from flash_attn.models.opt import remap_state_dict_hf_opt
23
+ from flash_attn.modules.block import Block, ParallelBlock
24
+ from flash_attn.modules.embedding import GPT2Embeddings, ParallelGPT2Embeddings
25
+ from flash_attn.modules.mha import MHA, ParallelMHA
26
+ from flash_attn.modules.mlp import (
27
+ FusedMLP,
28
+ GatedMlp,
29
+ Mlp,
30
+ ParallelFusedMLP,
31
+ ParallelGatedMlp,
32
+ ParallelMLP,
33
+ )
34
+ from flash_attn.ops.activations import sqrelu_fwd
35
+ from flash_attn.utils.distributed import (
36
+ all_gather,
37
+ all_gather_raw,
38
+ get_dim_for_local_rank,
39
+ sync_shared_params,
40
+ )
41
+ from flash_attn.utils.generation import GenerationMixin
42
+ from flash_attn.utils.pretrained import state_dict_from_pretrained
43
+
44
+ try:
45
+ from flash_attn.ops.fused_dense import ColumnParallelLinear
46
+ except ImportError:
47
+ ColumnParallelLinear = None
48
+
49
+ try:
50
+ from flash_attn.ops.triton.mlp import FusedDenseSqreluDense
51
+ except ImportError:
52
+ FusedDenseSqreluDense = None
53
+
54
+ try:
55
+ from flash_attn.ops.triton.layer_norm import layer_norm_fn, RMSNorm
56
+ except ImportError:
57
+ layer_norm_fn, RMSNorm = None, None
58
+
59
+ logger = logging.getLogger(__name__)
60
+
61
+
62
def create_mixer_cls(config, layer_idx=None, process_group=None, device=None, dtype=None):
    """Return a partially-applied constructor for this layer's attention mixer.

    Chooses ``MHA`` (single device) or ``ParallelMHA`` (tensor parallel) and
    bakes in every attention hyperparameter read from ``config``, so the block
    only needs to supply the embedding dimension when instantiating it.
    """
    factory_kwargs = {"device": device, "dtype": dtype}
    head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
    # muP variant: scale the qk dot product by d instead of sqrt(d).
    attn_scale_power = 1.0 if getattr(config, "mup_scale_qk_dot_by_d", False) else 0.5
    if config.scale_attn_weights:
        softmax_scale = head_dim ** (-attn_scale_power)
    else:
        softmax_scale = 1.0
    softmax_scale *= getattr(config, "mup_attn_multiplier", 1.0)
    if config.scale_attn_by_inverse_layer_idx:
        # GPT-2 option: additionally divide attention scores by the 1-based layer index.
        assert layer_idx is not None
        softmax_scale /= float(layer_idx + 1)
    dwconv = getattr(config, "attn_dwconv", False)
    if dwconv:
        assert process_group is None, "TensorParallel MHA does not support dwconv yet"
    fused_bias_fc = getattr(config, "fused_bias_fc", False)
    if not fused_bias_fc:
        assert process_group is None, "TensorParallel MHA requires fused_bias_fc"
    # Common keyword arguments shared by the serial and the tensor-parallel class.
    mixer_kwargs = dict(
        num_heads=config.num_attention_heads,
        num_heads_kv=getattr(config, "n_head_kv", None),
        qkv_proj_bias=getattr(config, "qkv_proj_bias", True),
        out_proj_bias=getattr(config, "out_proj_bias", True),
        dropout=config.attn_pdrop,
        softmax_scale=softmax_scale,
        causal=True,
        layer_idx=layer_idx,
        rotary_emb_dim=int(getattr(config, "rotary_emb_fraction", 0.0) * head_dim),
        rotary_emb_base=getattr(config, "rotary_emb_base", 10000.0),
        rotary_emb_scale_base=getattr(config, "rotary_emb_scale_base", None),
        rotary_emb_interleaved=getattr(config, "rotary_emb_interleaved", False),
        use_alibi=getattr(config, "use_alibi", False),
        window_size=getattr(config, "window_size", (-1, -1)),
        use_flash_attn=getattr(config, "use_flash_attn", False),
    )
    if process_group is None:
        mha_cls = MHA
        mixer_kwargs.update(fused_bias_fc=fused_bias_fc, dwconv=dwconv)
    else:
        mha_cls = ParallelMHA
        mixer_kwargs.update(
            process_group=process_group,
            sequence_parallel=getattr(config, "sequence_parallel", True),
        )
    return partial(mha_cls, **mixer_kwargs, **factory_kwargs)
121
+
122
+
123
def create_mlp_cls(config, layer_idx=None, process_group=None, device=None, dtype=None):
    """Return a partially-applied constructor for this layer's MLP.

    Depending on ``config`` this is a plain ``Mlp``, a ``GatedMlp`` (for
    glu/swiglu/geglu activations), a ``FusedMLP``, or a ``FusedDenseSqreluDense``;
    the ``Parallel*`` variant is chosen whenever ``process_group`` is given.
    """
    factory_kwargs = {"device": device, "dtype": dtype}
    mlp_fc1_bias = getattr(config, "mlp_fc1_bias", True)
    mlp_fc2_bias = getattr(config, "mlp_fc2_bias", True)
    fused_mlp = getattr(config, "fused_mlp", False)
    if fused_mlp:
        # The fused kernel only implements these activations.
        assert config.activation_function in [
            "gelu_new",
            "gelu_fast",
            "gelu_approx",
            "gelu_pytorch_tanh",
            "relu",
            "sqrelu",
        ]
    fused_dense_sqrelu_dense = getattr(config, "fused_dense_sqrelu_dense", False)
    if fused_dense_sqrelu_dense:
        assert config.activation_function == "sqrelu", (
            "fused_dense_sqrelu_dense only " "supports approximate activation_function sqrelu"
        )
    # The two fused variants are mutually exclusive.
    assert not (fused_dense_sqrelu_dense and fused_mlp)
    if not fused_mlp and not fused_dense_sqrelu_dense:
        # Unfused path: plain Mlp or GatedMlp with a torch activation callable.
        assert config.activation_function in [
            "gelu",
            "gelu_new",
            "gelu_fast",
            "gelu_approx",
            "gelu_pytorch_tanh",
            "relu",
            "sqrelu",
            "glu",
            "swiglu",
            "geglu",
        ]
        if config.activation_function in ["glu", "swiglu", "geglu"]:
            # Gated MLPs: the activation is the gate nonlinearity.
            activation = (
                F.sigmoid
                if config.activation_function == "glu"
                else (F.silu if config.activation_function == "swiglu" else F.gelu)
            )
            mlp_cls = GatedMlp if process_group is None else ParallelGatedMlp
            parallel_kwargs = (
                {
                    "process_group": process_group,
                    "sequence_parallel": getattr(config, "sequence_parallel", True),
                }
                if process_group is not None
                else {}
            )
            mlp_multiple_of = getattr(config, "mlp_multiple_of", 128)
            mlp_cls = partial(
                mlp_cls,
                hidden_features=config.n_inner,
                activation=activation,
                bias1=mlp_fc1_bias,
                bias2=mlp_fc2_bias,
                multiple_of=mlp_multiple_of,
                **parallel_kwargs,
                **factory_kwargs,
            )
        else:
            if config.activation_function == "relu":
                activation = partial(F.relu, inplace=True)
            elif config.activation_function == "sqrelu":
                activation = sqrelu_fwd
            else:
                # All "gelu_*" variants map to tanh-approximated GELU; plain "gelu" is exact.
                approximate = (
                    "tanh"
                    if config.activation_function
                    in ["gelu_new", "gelu_fast", "gelu_approx", "gelu_pytorch_tanh"]
                    else "none"
                )
                activation = partial(F.gelu, approximate=approximate)
            mlp_cls = Mlp if process_group is None else ParallelMLP
            parallel_kwargs = (
                {
                    "process_group": process_group,
                    "sequence_parallel": getattr(config, "sequence_parallel", True),
                }
                if process_group is not None
                else {}
            )
            mlp_cls = partial(
                mlp_cls,
                hidden_features=config.n_inner,
                activation=activation,
                bias1=mlp_fc1_bias,
                bias2=mlp_fc2_bias,
                **parallel_kwargs,
                **factory_kwargs,
            )
    else:
        mlp_checkpoint_lvl = getattr(config, "mlp_checkpoint_lvl", 0)
        # mlp_checkpoint_lvl could be a list, which contains the checkpoint_lvl for each layer
        if isinstance(mlp_checkpoint_lvl, Sequence):
            assert layer_idx is not None
            mlp_checkpoint_lvl = mlp_checkpoint_lvl[layer_idx]
        if fused_mlp:
            if FusedMLP is None:
                raise ImportError("fused_dense is not installed")
            # The fused kernel takes the activation by name, not as a callable.
            activation = (
                "gelu_approx"
                if config.activation_function
                in ["gelu_new", "gelu_fast", "gelu_approx", "gelu_pytorch_tanh"]
                else config.activation_function
            )
            mlp_cls = FusedMLP if process_group is None else ParallelFusedMLP
            parallel_kwargs = (
                {
                    "process_group": process_group,
                    "sequence_parallel": getattr(config, "sequence_parallel", True),
                }
                if process_group is not None
                else {}
            )
            mlp_cls = partial(
                mlp_cls,
                hidden_features=config.n_inner,
                activation=activation,
                checkpoint_lvl=mlp_checkpoint_lvl,
                bias1=mlp_fc1_bias,
                bias2=mlp_fc2_bias,
                **parallel_kwargs,
                **factory_kwargs,
            )
        elif fused_dense_sqrelu_dense:
            if process_group is not None:
                assert fused_mlp, "Tensor Parallel is not implemented for FusedDenseSqreluDense"
            assert FusedDenseSqreluDense is not None
            mlp_cls = partial(
                FusedDenseSqreluDense,
                hidden_features=config.n_inner,
                checkpoint_lvl=mlp_checkpoint_lvl,
                **factory_kwargs,
            )
        else:
            raise RuntimeError("MLP type not supported")
    return mlp_cls
260
+
261
+
262
def create_block(config, layer_idx=None, process_group=None, device=None, dtype=None):
    """Construct one transformer block (``Block`` or ``ParallelBlock``) for ``layer_idx``."""
    factory_kwargs = {"device": device, "dtype": dtype}
    sequence_parallel = getattr(config, "sequence_parallel", True)
    mixer_cls = create_mixer_cls(config, layer_idx, process_group=process_group, **factory_kwargs)
    mlp_cls = create_mlp_cls(config, layer_idx, process_group=process_group, **factory_kwargs)
    use_rms_norm = getattr(config, "rms_norm", False)
    norm_cls = partial(
        RMSNorm if use_rms_norm else nn.LayerNorm,
        eps=config.layer_norm_epsilon,
        **factory_kwargs,
    )
    # TD [2022-07-30]: Force residual in fp32, seems to make fp16 training more stable
    residual_in_fp32 = getattr(config, "residual_in_fp32", False)
    # Because dropout/add/norm are reordered, the first block's "resid" dropout is
    # actually the embedding dropout probability.
    is_first_layer = layer_idx is not None and layer_idx <= 0
    resid_dropout1 = config.embd_pdrop if is_first_layer else config.resid_pdrop
    prenorm = getattr(config, "prenorm", True)
    # Keyword arguments shared by both block flavors.
    common_kwargs = dict(
        norm_cls=norm_cls,
        resid_dropout1=resid_dropout1,
        resid_dropout2=config.resid_pdrop,
        fused_dropout_add_ln=getattr(config, "fused_dropout_add_ln", False),
        residual_in_fp32=residual_in_fp32,
        sequence_parallel=sequence_parallel and process_group is not None,
        mark_shared_params=process_group is not None,
    )
    if getattr(config, "parallel_block", False):
        # GPT-J / GPT-NeoX style: attention and MLP run in parallel, prenorm only.
        assert prenorm
        block = ParallelBlock(
            config.hidden_size,
            mixer_cls,
            mlp_cls,
            tied_norm=getattr(config, "parallel_block_tied_norm", False),
            **common_kwargs,
        )
    else:
        block = Block(
            config.hidden_size,
            mixer_cls,
            mlp_cls,
            prenorm=prenorm,
            **common_kwargs,
        )
    block.layer_idx = layer_idx
    return block
309
+
310
+
311
class GPTPreTrainedModel(nn.Module):
    """Base class for the GPT models in this file.

    Validates the config type and provides ``from_pretrained``, which downloads
    a HF checkpoint, remaps its state_dict to this implementation's layout, and
    loads it (optionally sharded for tensor parallelism).
    """

    def __init__(self, config, *inputs, **kwargs):
        super().__init__()
        if not isinstance(config, GPT2Config):
            raise ValueError(
                "Parameter config in `{}(config)` should be an instance of class `GPT2Config`. "
                "To create a model from a Google pretrained model use "
                "`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format(
                    self.__class__.__name__, self.__class__.__name__
                )
            )
        self.config = config

    @classmethod
    def from_pretrained(
        cls,
        model_name,
        config,
        *args,
        strict=True,
        device=None,
        dtype=None,
        world_size=1,
        rank=0,
        **kwargs,
    ):
        """
        Instantiate a GPTPreTrainedModel from a pre-trained model file or a pytorch state dict.
        Download and cache the pre-trained model file if needed.
        """
        # Instantiate the (randomly initialized) model first.
        model = cls(config, *args, device=device, dtype=dtype, **kwargs)
        # Load the checkpoint on CPU: the model already lives on the target device
        # and we don't want the raw state_dict taking up extra GPU memory.
        state_dict = state_dict_from_pretrained(model_name, device="cpu", dtype=dtype)
        # Dispatch to the matching HF -> flash_attn state_dict remapper by name prefix.
        remappers = (
            (("gpt2",), remap_state_dict_hf_gpt2),
            (("facebook/opt",), remap_state_dict_hf_opt),
            (("EleutherAI/gpt-j-", "togethercomputer/GPT-JT-"), remap_state_dict_hf_gptj),
            (
                (
                    "EleutherAI/gpt-neox-",
                    "EleutherAI/pythia-",
                    "togethercomputer/RedPajama-INCITE-",
                ),
                remap_state_dict_hf_gpt_neox,
            ),
            (("tiiuae/falcon-",), remap_state_dict_hf_falcon),
            (("meta-llama/Llama-",), remap_state_dict_hf_llama),
            (("bigcode/", "WizardLM/"), remap_state_dict_hf_bigcode),
        )
        for prefixes, remap_fn in remappers:
            if model_name.startswith(prefixes):
                state_dict = remap_fn(state_dict, config)
                break
        else:
            raise NotImplementedError(f"Model {model_name} not supported")
        if world_size > 1:
            # Keep only this rank's shard of each tensor-parallel parameter.
            state_dict = shard_state_dict_tp(state_dict, config, world_size, rank)
        load_return = model.load_state_dict(state_dict, strict=strict)
        logger.info(load_return)
        return model
377
+
378
+
379
+ # https://github.com/huggingface/transformers/blob/c28d04e9e252a1a099944e325685f14d242ecdcd/src/transformers/models/gpt2/modeling_gpt2.py#L454
380
def _init_weights(
    module, n_layer, initializer_range=0.02, mup_width_scale=1.0, rescale_prenorm_residual=True
):
    """Initialize one module GPT-2 style (intended for ``model.apply``).

    Linear weights get N(0, initializer_range * sqrt(mup_width_scale)) and a
    ``_optim`` learning-rate-multiplier hint for muP; biases are zeroed.
    Embeddings get N(0, initializer_range). When ``rescale_prenorm_residual``
    is set, residual-output projections (``out_proj.weight``, ``fc2.weight``)
    are re-drawn with std shrunk by sqrt(2 * n_layer), per the GPT-2 paper.
    """
    mup_init_scale = math.sqrt(mup_width_scale)
    if isinstance(module, nn.Linear):
        nn.init.normal_(module.weight, std=initializer_range * mup_init_scale)
        # Attach (or update in place) the muP optimizer hint on the weight tensor.
        optim_cfg = getattr(module.weight, "_optim", {})
        optim_cfg["lr_multiplier"] = mup_width_scale
        module.weight._optim = optim_cfg
        if module.bias is not None:
            nn.init.zeros_(module.bias)
    elif isinstance(module, nn.Embedding):
        nn.init.normal_(module.weight, std=initializer_range)

    if rescale_prenorm_residual:
        # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
        # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
        # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
        # > -- GPT-2 :: https://openai.com/blog/better-language-models/
        #
        # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
        residual_std = initializer_range * mup_init_scale / math.sqrt(2 * n_layer)
        for name, param in module.named_parameters():
            # There are 2 residual projections (attn out_proj, MLP fc2) per block.
            if name in ("out_proj.weight", "fc2.weight"):
                nn.init.normal_(param, mean=0.0, std=residual_std)
407
+
408
+
409
class GPTModel(GPTPreTrainedModel):
    """Decoder-only transformer backbone: token (+ optional position) embeddings,
    a stack of transformer blocks, and — for pre-norm models — a final
    dropout + layer norm.

    Supports tensor parallelism when ``process_group`` is given, optionally with
    sequence parallelism (``config.sequence_parallel``, default True).
    """

    def __init__(self, config: GPT2Config, process_group=None, device=None, dtype=None):
        super().__init__(config)
        factory_kwargs = {"device": device, "dtype": dtype}
        self.process_group = process_group
        self.sequence_parallel = getattr(config, "sequence_parallel", True)
        assert config.activation_function in [
            "gelu",
            "gelu_new",
            "gelu_fast",
            "gelu_approx",
            "gelu_pytorch_tanh",
            "relu",
            "sqrelu",
            "glu",
            "swiglu",
            "geglu",
        ]
        # Pad the vocabulary so it is a multiple of pad_vocab_size_multiple
        # (needed e.g. for even sharding under tensor parallelism).
        pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
        vocab_size = (
            math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple
        )
        # muP: optional scalar applied to the embedding output in forward().
        self.embeddings_multiplier = getattr(config, "mup_embeddings_multiplier", 1.0)
        # TD [2022-07-30]: Force residual in fp32, seems to make fp16 training more stable
        self.residual_in_fp32 = getattr(config, "residual_in_fp32", False)
        # These 2 options are for OPT-350m
        self.prenorm = getattr(config, "prenorm", True)
        use_rms_norm = getattr(config, "rms_norm", False)
        word_embed_proj_dim = getattr(config, "word_embed_proj_dim", None)
        # For GPT-J, GPT-NeoX
        self.parallel_block = getattr(config, "parallel_block", False)

        if process_group is None:
            self.embeddings = GPT2Embeddings(
                config.hidden_size,
                vocab_size,
                config.max_position_embeddings,
                word_embed_proj_dim=word_embed_proj_dim,
                **factory_kwargs,
            )
        else:
            self.embeddings = ParallelGPT2Embeddings(
                config.hidden_size,
                vocab_size,
                config.max_position_embeddings,
                process_group=process_group,
                sequence_parallel=self.sequence_parallel,
                **factory_kwargs,
            )

        # We change the order of dropout, residual and layer norm:
        # Instead of LN -> Attn / MLP -> Dropout -> Add, we do:
        # Dropout -> Add -> LN -> Attn / MLP, returning both the residual branch (output of Add) and
        # the main branch (output of MLP). The model definition is unchanged, but the mapping of the
        # nn.Dropout probabilities are changed.
        # This is for performance reason: we can fuse dropout + add + layer_norm.
        self.layers = nn.ModuleList(
            [
                create_block(config, layer_idx=i, process_group=process_group, **factory_kwargs)
                for i in range(config.num_hidden_layers)
            ]
        )
        rotary_emb_fraction = getattr(config, "rotary_emb_fraction", 0.0)
        if rotary_emb_fraction > 0.0:  # Tie all the RotaryEmbedding modules to share the same cos/sin cache
            for layer in self.layers[1:]:
                layer.mixer.rotary_emb = self.layers[0].mixer.rotary_emb

        self.fused_dropout_add_ln = getattr(config, "fused_dropout_add_ln", False)
        if self.fused_dropout_add_ln:
            if layer_norm_fn is None:
                raise ImportError("Triton is not installed")
        if self.prenorm:
            # Final dropout + norm applied after the last block (pre-norm models only).
            self.drop_f = nn.Dropout(config.resid_pdrop)
            norm_cls = nn.LayerNorm if not use_rms_norm else RMSNorm
            self.ln_f = norm_cls(
                config.hidden_size, eps=config.layer_norm_epsilon, **factory_kwargs
            )
        if process_group is not None:
            for p in self.ln_f.parameters():
                # Mark the norm parameters as "shared_params" so that we sync their values at init.
                p._shared_params = True
                # Mark the norm params as "sequence_parallel" so we run all-reduce on their grads.
                if self.sequence_parallel:
                    p._sequence_parallel = True

        self.apply(
            partial(
                _init_weights,
                n_layer=config.num_hidden_layers,
                initializer_range=config.initializer_range,
                mup_width_scale=getattr(config, "mup_width_scale", 1.0),
            )
        )
        self.tie_weights()

    def tie_weights(self):
        # Under tensor parallelism, make the shared params (e.g. final norm)
        # identical across ranks at init.
        if self.process_group is not None:
            sync_shared_params(self, self.process_group)

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
        # One KV-cache entry per layer, keyed by the layer index.
        return {
            i: layer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)
            for i, layer in enumerate(self.layers)
        }

    def forward(self, input_ids, position_ids=None, inference_params=None):
        """Run the backbone and return the final hidden states.

        input_ids: (batch, seqlen) int tensor.
        position_ids: optional explicit positions for the position embeddings.
        inference_params: KV-cache state for incremental decoding, or None.
        """
        # If using Tensor Parallel with sequence parallel, we combine the batch and the seqlen
        # dimensions so that we can split on it easily, in case of small batch size.
        # Only the attention layers need to know the seqlen.
        embedding_kwargs = (
            {"combine_batch_seqlen_dim": True}
            if self.process_group is not None and self.sequence_parallel
            else {}
        )
        hidden_states = self.embeddings(input_ids, position_ids=position_ids, **embedding_kwargs)
        if self.embeddings_multiplier != 1.0:
            hidden_states = hidden_states * self.embeddings_multiplier
        if self.parallel_block:
            # Second stream for GPT-J-style parallel attention + MLP blocks.
            hidden_states2 = None
        residual = None
        mixer_kwargs = (
            {"seqlen": input_ids.shape[1]}
            if self.process_group is not None and self.sequence_parallel
            else {}
        )
        if inference_params is not None:
            mixer_kwargs["inference_params"] = inference_params
        for layer in self.layers:
            if self.prenorm:
                if not self.parallel_block:
                    hidden_states, residual = layer(
                        hidden_states, residual, mixer_kwargs=mixer_kwargs
                    )
                else:
                    hidden_states, hidden_states2, residual = layer(
                        hidden_states, hidden_states2, residual, mixer_kwargs=mixer_kwargs
                    )
            else:
                hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)
        if self.prenorm:
            if not self.fused_dropout_add_ln:
                dropped = self.drop_f(hidden_states)
                if not self.parallel_block:
                    residual = (dropped + residual) if residual is not None else dropped
                else:
                    # Parallel blocks emit two streams; both are dropped and summed
                    # into the residual before the final norm.
                    dropped2 = self.drop_f(hidden_states2)
                    residual = (
                        (residual + dropped + dropped2)
                        if residual is not None
                        else dropped + dropped2
                    )
                hidden_states = self.ln_f(residual.to(dtype=self.ln_f.weight.dtype))
            else:
                # Set prenorm=False here since we don't need the residual
                hidden_states = layer_norm_fn(
                    hidden_states,
                    self.ln_f.weight,
                    self.ln_f.bias,
                    residual=residual,
                    x1=None if not self.parallel_block else hidden_states2,
                    eps=self.ln_f.eps,
                    dropout_p=self.drop_f.p if self.training else 0.0,
                    prenorm=False,
                    is_rms_norm=isinstance(self.ln_f, RMSNorm)
                )
        return hidden_states
575
+
576
+
577
class GPTLMHeadModel(GPTPreTrainedModel, GenerationMixin):
    """GPT backbone plus a (possibly tied / column-parallel) language-model head.

    ``GenerationMixin`` supplies the generation loop; this class supplies
    ``forward`` returning ``CausalLMOutput(logits=...)``.
    """

    def __init__(self, config: GPT2Config, process_group=None, device=None, dtype=None):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__(config)
        self.process_group = process_group
        self.transformer = GPTModel(config, process_group=process_group, **factory_kwargs)
        self.tie_word_embeddings = getattr(config, "tie_word_embeddings", True)
        lm_head_bias = getattr(config, "lm_head_bias", False)
        # Pad the vocabulary the same way GPTModel does so the head matches the embeddings.
        pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
        vocab_size = (
            math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple
        )
        # This option is for OPT-350m
        word_embed_proj_dim = getattr(config, "word_embed_proj_dim", None)
        embed_dim = config.n_embd if word_embed_proj_dim is None else word_embed_proj_dim
        if word_embed_proj_dim is not None:
            self.project_out = nn.Linear(config.n_embd, embed_dim, bias=False, **factory_kwargs)
        else:
            self.project_out = None
        # muP: scale applied to the hidden states right before the LM head.
        mup_width_scale = getattr(config, "mup_width_scale", 1.0)
        mup_output_multiplier = getattr(config, "mup_output_multiplier", 1.0)
        self.output_scale = mup_output_multiplier * mup_width_scale
        if process_group is None:
            self.lm_head = nn.Linear(embed_dim, vocab_size, bias=lm_head_bias, **factory_kwargs)
        else:
            if ColumnParallelLinear is None:
                raise ImportError("fused_dense_lib is not installed")
            self.lm_head = ColumnParallelLinear(
                embed_dim,
                vocab_size,
                process_group,
                bias=lm_head_bias,
                sequence_parallel=getattr(config, "sequence_parallel", True),
                **factory_kwargs,
            )
        # If set, L2-normalize the head weight before computing logits.
        self.norm_head = getattr(config, "norm_head", False)
        # Initialize weights and apply final processing
        self.apply(
            partial(
                _init_weights,
                n_layer=config.num_hidden_layers,
                initializer_range=config.initializer_range,
                mup_width_scale=mup_width_scale,
            )
        )
        self.tie_weights()

    def tie_weights(self):
        # Share the LM head weight with the input word embeddings (GPT-2 style).
        if self.tie_word_embeddings:
            self.lm_head.weight = self.transformer.embeddings.word_embeddings.weight
        if self.process_group is not None:
            sync_shared_params(self, self.process_group)

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
        # Delegates to the backbone: one KV-cache entry per layer.
        return self.transformer.allocate_inference_cache(
            batch_size, max_seqlen, dtype=dtype, **kwargs
        )

    def forward(self, input_ids, position_ids=None, inference_params=None, num_last_tokens=0):
        """
        input_ids: (batch, seqlen) int tensor
        inference_params: for generation. Adapted from Megatron-LM (and Apex)
        https://github.com/NVIDIA/apex/blob/3ff1a10f72ec07067c4e44759442329804ac5162/apex/transformer/testing/standalone_transformer_lm.py#L470
        num_last_tokens: if > 0, only return the logits for the last n tokens
        """
        assert (
            input_ids.ndim == 2
        ), f"Expected `input_ids` to have shape [b, slen], but got shape {input_ids.shape}"
        b, slen = input_ids.shape
        hidden_states = self.transformer(
            input_ids, position_ids=position_ids, inference_params=inference_params
        )
        if inference_params is not None:
            assert hidden_states.ndim == 3, "sequence_parallel is not supported in generation mode"
        if num_last_tokens > 0:
            hidden_states = hidden_states[:, -num_last_tokens:]
        if self.project_out is not None:
            hidden_states = self.project_out(hidden_states)
        if self.output_scale != 1.0:
            hidden_states = hidden_states * self.output_scale
        if not self.norm_head:
            lm_logits = self.lm_head(hidden_states)
        else:
            # Normalized-head variant: logits use the L2-normalized head weight.
            lm_head_weight = F.normalize(self.lm_head.weight)
            if isinstance(self.lm_head, ColumnParallelLinear) and self.lm_head.sequence_parallel:
                # Sequence-parallel input is sharded along seqlen; gather it first.
                hidden_states = all_gather(hidden_states, self.lm_head.process_group)
            lm_logits = F.linear(hidden_states, lm_head_weight, bias=self.lm_head.bias)
        # During inference, we want the full logit for sampling
        if isinstance(self.lm_head, ColumnParallelLinear) and inference_params is not None:
            lm_logits, _ = all_gather_raw(lm_logits, self.lm_head.process_group)
            lm_logits = rearrange(lm_logits, "(n b) ... d -> b ... (n d)", b=b)
        CausalLMOutput = namedtuple("CausalLMOutput", ["logits"])
        return CausalLMOutput(logits=lm_logits)

    def load_state_dict(self, state_dict, strict=True):
        # Remapping from our checkpoints that used a different ordering of layers in the block
        # Previous: Attn / MLP -> Dropout -> Add -> LN
        # Current: Dropout -> Add -> LN -> Attn / MLP
        if "transformer.ln_0.weight" in state_dict:
            n_layers = len(self.transformer.layers)
            # The last layer's norm2 becomes the final norm ln_f.
            ln_weight = state_dict.pop(f"transformer.layers.{n_layers - 1}.norm2.weight")
            ln_bias = state_dict.pop(f"transformer.layers.{n_layers - 1}.norm2.bias")
            state_dict["transformer.ln_f.weight"] = ln_weight
            state_dict["transformer.ln_f.bias"] = ln_bias
            # Shift each norm one position later: layer l's norm1 -> norm2, and
            # layer l-1's norm2 -> layer l's norm1. Iterate in reverse so pops
            # happen before their targets are overwritten.
            for l in reversed(range(n_layers)):
                ln_weight = state_dict.pop(f"transformer.layers.{l}.norm1.weight")
                ln_bias = state_dict.pop(f"transformer.layers.{l}.norm1.bias")
                state_dict[f"transformer.layers.{l}.norm2.weight"] = ln_weight
                state_dict[f"transformer.layers.{l}.norm2.bias"] = ln_bias
                if l > 0:
                    ln_weight = state_dict.pop(f"transformer.layers.{l - 1}.norm2.weight")
                    ln_bias = state_dict.pop(f"transformer.layers.{l - 1}.norm2.bias")
                    state_dict[f"transformer.layers.{l}.norm1.weight"] = ln_weight
                    state_dict[f"transformer.layers.{l}.norm1.bias"] = ln_bias
            # The old pre-stack norm ln_0 becomes layer 0's norm1.
            ln_weight = state_dict.pop("transformer.ln_0.weight")
            ln_bias = state_dict.pop("transformer.ln_0.bias")
            state_dict[f"transformer.layers.0.norm1.weight"] = ln_weight
            state_dict[f"transformer.layers.0.norm1.bias"] = ln_bias
        return super().load_state_dict(state_dict, strict=strict)
696
+
697
+
698
def shard_state_dict_tp(state_dict, config, world_size, rank):
    """Convert the state_dict of a standard GPT model to the state_dict of a GPT model
    with tensor parallel.

    This function modifies state_dict in place.
    """
    pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
    vocab_size = math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple
    assert vocab_size % world_size == 0
    assert config.hidden_size % world_size == 0
    inner_dim = config.n_inner if config.n_inner is not None else 4 * config.hidden_size
    assert inner_dim % world_size == 0

    n_head = config.n_head
    # n_head_kv < n_head means multi-query / grouped-query attention.
    n_head_kv = getattr(config, "n_head_kv", n_head)

    embed_dim = config.hidden_size
    head_dim = embed_dim // n_head

    def shard_first_dim(state_dict, key):
        # Keep this rank's contiguous slice along dim 0 (column-parallel weights).
        if key in state_dict:
            x = state_dict[key]
            dim = x.shape[0] // world_size
            state_dict[key] = x[rank * dim : (rank + 1) * dim]

    def shard_last_dim(state_dict, key, multiple_of=1):
        # Keep this rank's slice along the last dim (row-parallel weights);
        # per-rank sizes may be uneven but are kept a multiple of `multiple_of`.
        if key in state_dict:
            x = state_dict[key]
            dim_each_rank = [
                get_dim_for_local_rank(x.size(-1), world_size, local_rank, multiple_of)
                for local_rank in range(world_size)
            ]
            beg, end = tuple(sum(dim_each_rank[:pos]) for pos in (rank, rank + 1))
            state_dict[key] = x[..., beg:end]

    def shard_gatedmlp_fc1_dim(state_dict, key):
        # Gated-MLP fc1 stacks [gate; value] along dim 0: shard each half
        # separately so every rank keeps matching gate/value slices.
        if key in state_dict:
            x = state_dict[key]
            dim = x.shape[0] // world_size // 2
            state_dict[key] = rearrange(
                rearrange(x, "(two o) ... -> two o ...", two=2)[:, rank * dim : (rank + 1) * dim],
                "two o ... -> (two o) ...",
            )

    def shard_qkv_headdim(state_dict, key):
        # Shard the packed Wqkv tensor by attention head, keeping this rank's
        # q heads and (for MQA/GQA) its kv heads.
        if key in state_dict:
            n_head_each_rank = [
                get_dim_for_local_rank(n_head, world_size, local_rank)
                for local_rank in range(world_size)
            ]
            n_head_kv_each_rank = [
                get_dim_for_local_rank(n_head_kv, world_size, local_rank)
                for local_rank in range(world_size)
            ]

            beg_n_head = sum(n_head_each_rank[:rank])
            end_n_head = sum(n_head_each_rank[: rank + 1])

            beg_n_head_kv = sum(n_head_kv_each_rank[:rank])
            end_n_head_kv = sum(n_head_kv_each_rank[: rank + 1])

            if n_head_kv == n_head:
                # Standard MHA: Wqkv is [q; k; v] stacked as (3, embed_dim); the
                # same head slice applies to all three.
                x = rearrange(state_dict[key], "(three d) ... -> three d ...", three=3)
                state_dict[key] = rearrange(
                    x[:, beg_n_head * head_dim : end_n_head * head_dim],
                    "three d ... -> (three d) ...",
                )
            else:
                # MQA/GQA: layout is [q heads | k heads | v heads]; slice each
                # section independently, then re-pack.
                x = rearrange(
                    state_dict[key],
                    "(nheadqkv headdim) ... -> nheadqkv headdim ...",
                    nheadqkv=n_head + 2 * n_head_kv,
                )
                state_dict[key] = rearrange(
                    torch.cat(
                        [
                            x[beg_n_head:end_n_head],
                            x[n_head + beg_n_head_kv : n_head + end_n_head_kv],
                            x[
                                n_head
                                + n_head_kv
                                + beg_n_head_kv : n_head
                                + n_head_kv
                                + end_n_head_kv
                            ],
                        ],
                        dim=0,
                    ),
                    "nheadqkv headdim ... -> (nheadqkv headdim) ...",
                )

    shard_first_dim(state_dict, "transformer.embeddings.word_embeddings.weight")
    if "lm_head.weight" in state_dict:
        shard_first_dim(state_dict, "lm_head.weight")
    if "transformer.embeddings.position_embeddings.weight" in state_dict:
        shard_last_dim(state_dict, "transformer.embeddings.position_embeddings.weight")
    for i in range(config.num_hidden_layers):
        shard_qkv_headdim(state_dict, f"transformer.layers.{i}.mixer.Wqkv.weight")
        shard_qkv_headdim(state_dict, f"transformer.layers.{i}.mixer.Wqkv.bias")
        shard_last_dim(
            state_dict, f"transformer.layers.{i}.mixer.out_proj.weight", multiple_of=head_dim
        )
        # Row-parallel biases are only added on rank 0 (after the all-reduce).
        if rank != 0:
            state_dict.pop(f"transformer.layers.{i}.mixer.out_proj.bias", None)
        if config.activation_function in ["glu", "swiglu", "geglu"]:
            shard_gatedmlp_fc1_dim(state_dict, f"transformer.layers.{i}.mlp.fc1.weight")
            shard_gatedmlp_fc1_dim(state_dict, f"transformer.layers.{i}.mlp.fc1.bias")
        else:
            shard_first_dim(state_dict, f"transformer.layers.{i}.mlp.fc1.weight")
            shard_first_dim(state_dict, f"transformer.layers.{i}.mlp.fc1.bias")
        shard_last_dim(state_dict, f"transformer.layers.{i}.mlp.fc2.weight")
        if rank != 0:
            state_dict.pop(f"transformer.layers.{i}.mlp.fc2.bias", None)
    return state_dict
812
+
813
+
814
+ def combine_state_dicts_tp(state_dicts: List[Dict[str, torch.Tensor]], config: GPT2Config):
815
+ """Convert the list of sharded state_dict of a GPT model with tensor parallel to
816
+ the state_dict of a standard GPT model.
817
+
818
+ This function is meant to be the "reverse" of shard_state_dict_tp.
819
+
820
+ Precondition:
821
+ - state_dicts should be ordered in the same way as the shards were created.
822
+ """
823
+ world_size = len(state_dicts)
824
+ keys = state_dicts[0].keys()
825
+ pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
826
+ vocab_size = math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple
827
+ assert vocab_size % world_size == 0
828
+ assert config.hidden_size % world_size == 0
829
+ inner_dim = config.n_inner if config.n_inner is not None else 4 * config.hidden_size
830
+ assert inner_dim % world_size == 0
831
+ assert config.hidden_size % config.n_head == 0
832
+ headdim = config.hidden_size // config.n_head
833
+
834
+ # Sometimes the word embeddings are sharded on the 0th dim, sometimes on the 1st dim.
835
+ # vocab_size // world_size coordinates are nonzero.
836
+ def combine_word_embeddings(state_dicts, state_dict, key):
837
+ dim = 0 if state_dicts[0][key].shape[0] == vocab_size // world_size else 1
838
+ state_dict[key] = torch.cat([s[key] for s in state_dicts], dim=dim)
839
+
840
+ def combine_dim(state_dicts, state_dict, key, dim=-1):
841
+ if key in state_dict:
842
+ state_dict[key] = torch.cat([s[key] for s in state_dicts], dim=dim)
843
+
844
+ def combine_qkv_headdim(state_dicts, state_dict, key):
845
+ n_head = config.n_head
846
+ n_head_kv = getattr(config, "n_head_kv", n_head)
847
+ if key in state_dict:
848
+ if n_head_kv == n_head:
849
+ xs = [
850
+ rearrange(s[key], "(three d) ... -> three d ...", three=3) for s in state_dicts
851
+ ]
852
+ state_dict[key] = rearrange(torch.cat(xs, dim=1), "three d ... -> (three d) ...")
853
+ else:
854
+ n_head_each_rank = [
855
+ get_dim_for_local_rank(n_head, world_size, local_rank)
856
+ for local_rank in range(world_size)
857
+ ]
858
+ n_head_kv_each_rank = [
859
+ get_dim_for_local_rank(n_head_kv, world_size, local_rank)
860
+ for local_rank in range(world_size)
861
+ ]
862
+ xs = [
863
+ rearrange(
864
+ s[key],
865
+ "(nheadqkv headdim) ... -> nheadqkv headdim ...",
866
+ nheadqkv=rank_n_head + 2 * rank_n_head_kv,
867
+ headdim=headdim,
868
+ )
869
+ for s, rank_n_head, rank_n_head_kv in zip(
870
+ state_dicts, n_head_each_rank, n_head_kv_each_rank
871
+ )
872
+ ]
873
+ wq = torch.cat([x[: n_head_each_rank[rank]] for rank, x in enumerate(xs)], dim=0)
874
+ wk = torch.cat(
875
+ [
876
+ x[
877
+ n_head_each_rank[rank] : n_head_each_rank[rank]
878
+ + n_head_kv_each_rank[rank]
879
+ ]
880
+ for rank, x in enumerate(xs)
881
+ ],
882
+ dim=0,
883
+ )
884
+ wv = torch.cat(
885
+ [
886
+ x[n_head_each_rank[rank] + n_head_kv_each_rank[rank] :]
887
+ for rank, x in enumerate(xs)
888
+ ],
889
+ dim=0,
890
+ )
891
+ wqkv = torch.cat(
892
+ [wq, wk, wv],
893
+ dim=0,
894
+ )
895
+ state_dict[key] = rearrange(
896
+ wqkv,
897
+ "nheadqkv headdim ... -> (nheadqkv headdim) ...",
898
+ )
899
+
900
+ def combine_gated_mlp(state_dicts, state_dict, key):
901
+ if key in state_dict:
902
+ xs = [rearrange(s[key], "(two d) ... -> two d ...", two=2) for s in state_dicts]
903
+ state_dict[key] = rearrange(torch.cat(xs, dim=1), "two d ... -> (two d) ...")
904
+
905
+ state_dict = state_dicts[0].copy() # don't modify state_dict[0] inplace
906
+ combine_word_embeddings(
907
+ state_dicts, state_dict, "transformer.embeddings.word_embeddings.weight"
908
+ )
909
+ if "lm_head.weight" in state_dict:
910
+ combine_word_embeddings(state_dicts, state_dict, "lm_head.weight")
911
+ if "transformer.embeddings.position_embeddings.weight" in state_dict:
912
+ combine_dim(
913
+ state_dicts, state_dict, "transformer.embeddings.position_embeddings.weight", -1
914
+ )
915
+ mlp_combine_fn = (
916
+ combine_gated_mlp
917
+ if config.activation_function in ["glu", "swiglu", "geglu"]
918
+ else partial(combine_dim, dim=0)
919
+ )
920
+ for i in range(config.num_hidden_layers):
921
+ combine_qkv_headdim(state_dicts, state_dict, f"transformer.layers.{i}.mixer.Wqkv.weight")
922
+ combine_qkv_headdim(state_dicts, state_dict, f"transformer.layers.{i}.mixer.Wqkv.bias")
923
+ combine_dim(state_dicts, state_dict, f"transformer.layers.{i}.mixer.out_proj.weight", -1)
924
+ mlp_combine_fn(state_dicts, state_dict, f"transformer.layers.{i}.mlp.fc1.weight")
925
+ combine_dim(state_dicts, state_dict, f"transformer.layers.{i}.mlp.fc1.bias", 0)
926
+ combine_dim(state_dicts, state_dict, f"transformer.layers.{i}.mlp.fc2.weight", -1)
927
+ return state_dict
928
+
929
+
930
+ def remap_state_dict_hf_gpt2(state_dict, config):
931
+ # Word embedding and position embedding
932
+ def key_mapping_pos_emb(key):
933
+ return re.sub(r"^wpe.", "transformer.embeddings.position_embeddings.", key)
934
+
935
+ state_dict = OrderedDict((key_mapping_pos_emb(k), v) for k, v in state_dict.items())
936
+ word_embeddings = state_dict.pop("wte.weight")
937
+ # It's possible that vocab_size is padded to be a multiple of 8, for example.
938
+ pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
939
+ vocab_size = math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple
940
+ state_dict["transformer.embeddings.word_embeddings.weight"] = F.pad(
941
+ word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0])
942
+ )
943
+ state_dict["lm_head.weight"] = state_dict["transformer.embeddings.word_embeddings.weight"]
944
+
945
+ # LayerNorm
946
+ def key_mapping_ln(key):
947
+ key = re.sub(r"^ln_f.(weight|bias)", r"transformer.ln_f.\1", key)
948
+ key = re.sub(r"^h.(\d+).ln_(1|2).(weight|bias)", r"transformer.layers.\1.norm\2.\3", key)
949
+ return key
950
+
951
+ state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items())
952
+
953
+ # MLP
954
+ for d in range(config.num_hidden_layers):
955
+ W1 = state_dict.pop(f"h.{d}.mlp.c_fc.weight")
956
+ state_dict[f"transformer.layers.{d}.mlp.fc1.weight"] = W1.t()
957
+ W2 = state_dict.pop(f"h.{d}.mlp.c_proj.weight")
958
+ state_dict[f"transformer.layers.{d}.mlp.fc2.weight"] = W2.t()
959
+
960
+ def key_mapping_mlp(key):
961
+ key = re.sub(r"^h.(\d+).mlp.c_fc.bias", r"transformer.layers.\1.mlp.fc1.bias", key)
962
+ key = re.sub(r"^h.(\d+).mlp.c_proj.bias", r"transformer.layers.\1.mlp.fc2.bias", key)
963
+ return key
964
+
965
+ state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items())
966
+
967
+ # Attention
968
+ for d in range(config.num_hidden_layers):
969
+ state_dict.pop(f"h.{d}.attn.bias", None) # We don't store this bias
970
+ Wqkv = state_dict.pop(f"h.{d}.attn.c_attn.weight")
971
+ state_dict[f"transformer.layers.{d}.mixer.Wqkv.weight"] = Wqkv.t()
972
+ Wout = state_dict.pop(f"h.{d}.attn.c_proj.weight")
973
+ state_dict[f"transformer.layers.{d}.mixer.out_proj.weight"] = Wout.t()
974
+
975
+ def key_mapping_attn(key):
976
+ key = re.sub(r"^h.(\d+).attn.c_attn.bias", r"transformer.layers.\1.mixer.Wqkv.bias", key)
977
+ key = re.sub(
978
+ r"^h.(\d+).attn.c_proj.bias", r"transformer.layers.\1.mixer.out_proj.bias", key
979
+ )
980
+ return key
981
+
982
+ state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())
983
+
984
+ return state_dict
985
+
986
+
987
+ def remap_state_dict_megatron(state_dict, config):
988
+ def key_mapping_transformer(key):
989
+ key = re.sub(r"^language_model.encoder.", "transformer.", key)
990
+ key = re.sub(r"^language_model.", "transformer.", key)
991
+ return key
992
+
993
+ state_dict = OrderedDict((key_mapping_transformer(k), v) for k, v in state_dict.items())
994
+
995
+ # Word embedding and position embedding
996
+ def key_mapping_pos_emb(key):
997
+ return re.sub(r"^wpe.", "transformer.embeddings.position_embeddings.", key)
998
+
999
+ state_dict = OrderedDict((key_mapping_pos_emb(k), v) for k, v in state_dict.items())
1000
+ word_embeddings = state_dict.pop("transformer.embedding.word_embeddings.weight")
1001
+ # It's possible that vocab_size is padded to be a multiple of 8, for example.
1002
+ pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
1003
+ vocab_size = (
1004
+ math.ceil(word_embeddings.shape[0] / pad_vocab_size_multiple) * pad_vocab_size_multiple
1005
+ )
1006
+ state_dict["transformer.embeddings.word_embeddings.weight"] = F.pad(
1007
+ word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0])
1008
+ )
1009
+ state_dict["lm_head.weight"] = state_dict["transformer.embeddings.word_embeddings.weight"]
1010
+
1011
+ # LayerNorm
1012
+ def key_mapping_ln(key):
1013
+ key = re.sub(r"^transformer.final_layernorm.(weight|bias)", r"transformer.ln_f.\1", key)
1014
+ key = re.sub(
1015
+ r"^transformer.layers.(\d+).input_layernorm.(weight|bias)",
1016
+ r"transformer.layers.\1.norm1.\2",
1017
+ key,
1018
+ )
1019
+ key = re.sub(
1020
+ r"^transformer.layers.(\d+).post_attention_layernorm.(weight|bias)",
1021
+ r"transformer.layers.\1.norm2.\2",
1022
+ key,
1023
+ )
1024
+ return key
1025
+
1026
+ state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items())
1027
+
1028
+ # MLP
1029
+ def key_mapping_mlp(key):
1030
+ key = re.sub(
1031
+ r"^transformer.layers.(\d+).mlp.dense_h_to_4h.(weight|bias)",
1032
+ r"transformer.layers.\1.mlp.fc1.\2",
1033
+ key,
1034
+ )
1035
+ key = re.sub(
1036
+ r"^transformer.layers.(\d+).mlp.dense_4h_to_h.(weight|bias)",
1037
+ r"transformer.layers.\1.mlp.fc2.\2",
1038
+ key,
1039
+ )
1040
+ return key
1041
+
1042
+ state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items())
1043
+
1044
+ # Attention
1045
+ def key_mapping_attn(key):
1046
+ key = re.sub(
1047
+ r"^transformer.layers.(\d+).self_attention.rotary_emb.inv_freq",
1048
+ r"transformer.layers.\1.mixer.rotary_emb.inv_freq",
1049
+ key,
1050
+ )
1051
+ key = re.sub(
1052
+ r"^transformer.layers.(\d+).self_attention.query_key_value.(weight|bias)",
1053
+ r"transformer.layers.\1.mixer.Wqkv.\2",
1054
+ key,
1055
+ )
1056
+ key = re.sub(
1057
+ r"^transformer.layers.(\d+).self_attention.dense.(weight|bias)",
1058
+ r"transformer.layers.\1.mixer.out_proj.\2",
1059
+ key,
1060
+ )
1061
+ return key
1062
+
1063
+ state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())
1064
+ # Megatron stores Wqkv as ((nheads 3 headdim), hidden_dim)
1065
+ # while we store Wqkv as ((3 nheads headdim), hidden_dim)
1066
+ headdim = config.hidden_size // config.num_attention_heads
1067
+ for d in range(config.num_hidden_layers):
1068
+ Wqkv = state_dict.pop(f"transformer.layers.{d}.mixer.Wqkv.weight")
1069
+ state_dict[f"transformer.layers.{d}.mixer.Wqkv.weight"] = rearrange(
1070
+ Wqkv,
1071
+ "(nheads three headdim) ... -> (three nheads headdim) ...",
1072
+ three=3,
1073
+ headdim=headdim,
1074
+ )
1075
+ bqkv = state_dict.pop(f"transformer.layers.{d}.mixer.Wqkv.bias")
1076
+ state_dict[f"transformer.layers.{d}.mixer.Wqkv.bias"] = rearrange(
1077
+ bqkv, "(nheads three headdim) -> (three nheads headdim)", three=3, headdim=headdim
1078
+ )
1079
+
1080
+ return state_dict
.venv/lib/python3.11/site-packages/xformers/_flash_attn/models/gpt_neox.py ADDED
@@ -0,0 +1,124 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2023, Tri Dao.
2
+
3
+ import math
4
+ import re
5
+ from collections import OrderedDict
6
+
7
+ import torch
8
+ import torch.nn.functional as F
9
+ from einops import rearrange
10
+ from transformers import GPT2Config, GPTNeoXConfig
11
+
12
+
13
+ def remap_state_dict_hf_gpt_neox(state_dict, config):
14
+ def key_mapping_layers(key):
15
+ return re.sub(r"^gpt_neox.", "transformer.", key)
16
+
17
+ state_dict = OrderedDict((key_mapping_layers(k), v) for k, v in state_dict.items())
18
+ # Word embedding
19
+ def key_mapping_emb(key):
20
+ return re.sub(r"^transformer.embed_in.", "transformer.embeddings.word_embeddings.", key)
21
+
22
+ state_dict = OrderedDict((key_mapping_emb(k), v) for k, v in state_dict.items())
23
+ word_embeddings = state_dict.pop("transformer.embeddings.word_embeddings.weight")
24
+ # It's possible that vocab_size is padded to be a multiple of 8, for example.
25
+ pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
26
+ vocab_size = math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple
27
+ state_dict["transformer.embeddings.word_embeddings.weight"] = F.pad(
28
+ word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0])
29
+ )
30
+ if getattr(config, "tie_word_embeddings", False):
31
+ state_dict["lm_head.weight"] = state_dict["transformer.embeddings.word_embeddings.weight"]
32
+ else:
33
+ output_embeddings = state_dict.pop("embed_out.weight")
34
+ # It's possible that vocab_size is padded to be a multiple of 8, for example.
35
+ state_dict["lm_head.weight"] = F.pad(
36
+ output_embeddings, (0, 0, 0, vocab_size - output_embeddings.shape[0])
37
+ )
38
+
39
+ # LayerNorm
40
+ def key_mapping_ln(key):
41
+ key = re.sub(r"^transformer.final_layer_norm.", r"transformer.ln_f.", key)
42
+ key = re.sub(
43
+ r"^transformer.layers.(\d+).input_layernorm.", r"transformer.layers.\1.norm1.", key
44
+ )
45
+ key = re.sub(
46
+ r"^transformer.layers.(\d+).post_attention_layernorm.",
47
+ r"transformer.layers.\1.norm2.",
48
+ key,
49
+ )
50
+ return key
51
+
52
+ state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items())
53
+
54
+ # MLP
55
+ def key_mapping_mlp(key):
56
+ key = re.sub(
57
+ r"^transformer.layers.(\d+).mlp.dense_h_to_4h.", r"transformer.layers.\1.mlp.fc1.", key
58
+ )
59
+ key = re.sub(
60
+ r"^transformer.layers.(\d+).mlp.dense_4h_to_h.", r"transformer.layers.\1.mlp.fc2.", key
61
+ )
62
+ return key
63
+
64
+ state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items())
65
+
66
+ # Attention
67
+ for l in range(config.n_layer):
68
+ # We don't store these biases
69
+ state_dict.pop(f"transformer.layers.{l}.attention.bias")
70
+ state_dict.pop(f"transformer.layers.{l}.attention.masked_bias")
71
+ # We don't store these
72
+ state_dict.pop(f"transformer.layers.{l}.attention.rotary_emb.inv_freq", None)
73
+ # GPT-NeoX stores Wqkv as ((nheads 3 headdim), hidden_dim)
74
+ # while we store Wqkv as ((3 nheads headdim), hidden_dim)
75
+ headdim = config.hidden_size // config.num_attention_heads
76
+ Wqkv = state_dict.pop(f"transformer.layers.{l}.attention.query_key_value.weight")
77
+ state_dict[f"transformer.layers.{l}.mixer.Wqkv.weight"] = rearrange(
78
+ Wqkv,
79
+ "(nheads three headdim) ... -> (three nheads headdim) ...",
80
+ three=3,
81
+ headdim=headdim,
82
+ )
83
+ bqkv = state_dict.pop(f"transformer.layers.{l}.attention.query_key_value.bias")
84
+ state_dict[f"transformer.layers.{l}.mixer.Wqkv.bias"] = rearrange(
85
+ bqkv, "(nheads three headdim) -> (three nheads headdim)", three=3, headdim=headdim
86
+ )
87
+
88
+ def key_mapping_attn(key):
89
+ key = re.sub(
90
+ r"^transformer.layers.(\d+).attention.dense.",
91
+ r"transformer.layers.\1.mixer.out_proj.",
92
+ key,
93
+ )
94
+ return key
95
+
96
+ state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())
97
+
98
+ return state_dict
99
+
100
+
101
+ def gpt_neox_config_to_gpt2_config(gpt_neox_config: GPTNeoXConfig) -> GPT2Config:
102
+ assert gpt_neox_config.rotary_emb_base == 10000
103
+ return GPT2Config(
104
+ vocab_size=gpt_neox_config.vocab_size,
105
+ n_positions=0, # No absolute position embedding
106
+ n_embd=gpt_neox_config.hidden_size,
107
+ n_layer=gpt_neox_config.num_hidden_layers,
108
+ n_head=gpt_neox_config.num_attention_heads,
109
+ n_inner=gpt_neox_config.intermediate_size,
110
+ activation_function=gpt_neox_config.hidden_act,
111
+ resid_pdrop=0.0, # No dropout
112
+ embd_pdrop=0.0,
113
+ attn_pdrop=0.0,
114
+ layer_norm_epsilon=gpt_neox_config.layer_norm_eps,
115
+ initializer_range=gpt_neox_config.initializer_range,
116
+ bos_token_id=gpt_neox_config.bos_token_id,
117
+ eos_token_id=gpt_neox_config.eos_token_id,
118
+ # These are new arguments not in the original GPT2Config
119
+ prenorm=True,
120
+ parallel_block=gpt_neox_config.use_parallel_residual,
121
+ parallel_block_tied_norm=False,
122
+ rotary_emb_fraction=gpt_neox_config.rotary_pct,
123
+ tie_word_embeddings=gpt_neox_config.tie_word_embeddings,
124
+ )
.venv/lib/python3.11/site-packages/xformers/_flash_attn/models/gptj.py ADDED
@@ -0,0 +1,109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2023, Tri Dao.
2
+
3
+ import math
4
+ import re
5
+ from collections import OrderedDict
6
+
7
+ import torch
8
+ import torch.nn.functional as F
9
+ from transformers import GPT2Config, GPTJConfig
10
+
11
+
12
+ def remap_state_dict_hf_gptj(state_dict, config):
13
+ def key_mapping_layers(key):
14
+ return re.sub(r"^transformer.h.", "transformer.layers.", key)
15
+
16
+ state_dict = OrderedDict((key_mapping_layers(k), v) for k, v in state_dict.items())
17
+ # Word embedding
18
+ def key_mapping_emb(key):
19
+ return re.sub(r"^transformer.wte.", "transformer.embeddings.word_embeddings.", key)
20
+
21
+ state_dict = OrderedDict((key_mapping_emb(k), v) for k, v in state_dict.items())
22
+ word_embeddings = state_dict.pop("transformer.embeddings.word_embeddings.weight")
23
+ # It's possible that vocab_size is padded to be a multiple of 8, for example.
24
+ pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
25
+ vocab_size = math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple
26
+ state_dict["transformer.embeddings.word_embeddings.weight"] = F.pad(
27
+ word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0])
28
+ )
29
+ if getattr(config, "tie_word_embeddings"):
30
+ state_dict["lm_head.weight"] = state_dict["transformer.embeddings.word_embeddings.weight"]
31
+ else:
32
+ output_embeddings = state_dict.pop("lm_head.weight")
33
+ # It's possible that vocab_size is padded to be a multiple of 8, for example.
34
+ state_dict["lm_head.weight"] = F.pad(
35
+ output_embeddings, (0, 0, 0, vocab_size - output_embeddings.shape[0])
36
+ )
37
+ output_embeddings_bias = state_dict.pop("lm_head.bias")
38
+ state_dict["lm_head.bias"] = F.pad(
39
+ output_embeddings_bias, (0, vocab_size - output_embeddings_bias.shape[0])
40
+ )
41
+
42
+ # LayerNorm
43
+ def key_mapping_ln(key):
44
+ return re.sub(r"^transformer.layers.(\d+).ln_1.", r"transformer.layers.\1.norm1.", key)
45
+
46
+ state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items())
47
+
48
+ # MLP
49
+ def key_mapping_mlp(key):
50
+ key = re.sub(
51
+ r"^transformer.layers.(\d+).mlp.fc_in.", r"transformer.layers.\1.mlp.fc1.", key
52
+ )
53
+ key = re.sub(
54
+ r"^transformer.layers.(\d+).mlp.fc_out.", r"transformer.layers.\1.mlp.fc2.", key
55
+ )
56
+ return key
57
+
58
+ state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items())
59
+
60
+ # Attention
61
+ for l in range(config.n_layer):
62
+ Wq = state_dict.pop(f"transformer.layers.{l}.attn.q_proj.weight")
63
+ Wk = state_dict.pop(f"transformer.layers.{l}.attn.k_proj.weight")
64
+ Wv = state_dict.pop(f"transformer.layers.{l}.attn.v_proj.weight")
65
+ state_dict[f"transformer.layers.{l}.mixer.Wqkv.weight"] = torch.cat([Wq, Wk, Wv], dim=0)
66
+ # We don't store these biases
67
+ state_dict.pop(f"transformer.layers.{l}.attn.bias")
68
+ state_dict.pop(f"transformer.layers.{l}.attn.masked_bias")
69
+
70
+ def key_mapping_attn(key):
71
+ return re.sub(
72
+ r"^transformer.layers.(\d+).attn.out_proj.",
73
+ r"transformer.layers.\1.mixer.out_proj.",
74
+ key,
75
+ )
76
+
77
+ state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())
78
+
79
+ return state_dict
80
+
81
+
82
+ def gptj_config_to_gpt2_config(gptj_config: GPTJConfig) -> GPT2Config:
83
+ headdim = gptj_config.n_embd // gptj_config.n_head
84
+ return GPT2Config(
85
+ vocab_size=gptj_config.vocab_size,
86
+ n_positions=0, # No absolute position embedding
87
+ n_embd=gptj_config.n_embd,
88
+ n_layer=gptj_config.n_layer,
89
+ n_head=gptj_config.n_head,
90
+ n_inner=gptj_config.n_inner,
91
+ activation_function=gptj_config.activation_function,
92
+ resid_pdrop=gptj_config.resid_pdrop,
93
+ embd_pdrop=gptj_config.embd_pdrop,
94
+ attn_pdrop=gptj_config.attn_pdrop,
95
+ layer_norm_epsilon=gptj_config.layer_norm_epsilon,
96
+ initializer_range=gptj_config.initializer_range,
97
+ bos_token_id=gptj_config.bos_token_id,
98
+ eos_token_id=gptj_config.eos_token_id,
99
+ # These are new arguments not in the original GPT2Config
100
+ prenorm=True,
101
+ parallel_block=True,
102
+ parallel_block_tied_norm=True,
103
+ rotary_emb_fraction=gptj_config.rotary_dim / headdim,
104
+ rotary_emb_interleaved=True,
105
+ tie_word_embeddings=False,
106
+ qkv_proj_bias=False,
107
+ out_proj_bias=False,
108
+ lm_head_bias=True,
109
+ )
.venv/lib/python3.11/site-packages/xformers/_flash_attn/models/llama.py ADDED
@@ -0,0 +1,422 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2023, Tri Dao.
2
+
3
+ import json
4
+ import math
5
+ import os
6
+ import re
7
+ from collections import OrderedDict
8
+ from pathlib import Path
9
+ from typing import Dict, List, Union
10
+
11
+ import torch
12
+ import torch.nn.functional as F
13
+ from sentencepiece import SentencePieceProcessor
14
+ from transformers import GPT2Config, LlamaConfig
15
+
16
+ from einops import rearrange
17
+
18
+
19
+ def remap_state_dict_meta_llama(
20
+ state_dict: Dict[str, torch.Tensor], config: GPT2Config
21
+ ) -> Dict[str, torch.Tensor]:
22
+ """Convert the state_dict in Meta format to standard GPT format.
23
+
24
+ This function modifies state_dict in place.
25
+ """
26
+
27
+ def key_mapping_layers(key):
28
+ return f"transformer.{key}" if not key.startswith("output.") else key
29
+
30
+ state_dict = OrderedDict((key_mapping_layers(k), v) for k, v in state_dict.items())
31
+
32
+ # Word embedding
33
+ def key_mapping_emb(key):
34
+ return re.sub(
35
+ r"^transformer.tok_embeddings.", "transformer.embeddings.word_embeddings.", key
36
+ )
37
+
38
+ state_dict = OrderedDict((key_mapping_emb(k), v) for k, v in state_dict.items())
39
+ word_embeddings = state_dict.pop("transformer.embeddings.word_embeddings.weight")
40
+ # It's possible that vocab_size is padded to be a multiple of 8, for example.
41
+ pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
42
+ vocab_size = (
43
+ math.ceil(word_embeddings.shape[0] / pad_vocab_size_multiple) * pad_vocab_size_multiple
44
+ )
45
+ state_dict["transformer.embeddings.word_embeddings.weight"] = F.pad(
46
+ word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0])
47
+ )
48
+ if getattr(config, "tie_word_embeddings"):
49
+ state_dict["lm_head.weight"] = state_dict["transformer.embeddings.word_embeddings.weight"]
50
+ else:
51
+ output_embeddings = state_dict.pop("output.weight")
52
+ # Need to recompute vocab_size since LLaMa shards the word embeddings and output embeddings
53
+ # differently.
54
+ vocab_size = (
55
+ math.ceil(output_embeddings.shape[0] / pad_vocab_size_multiple)
56
+ * pad_vocab_size_multiple
57
+ )
58
+ # It's possible that vocab_size is padded to be a multiple of 8, for example.
59
+ state_dict["lm_head.weight"] = F.pad(
60
+ output_embeddings, (0, 0, 0, vocab_size - output_embeddings.shape[0])
61
+ )
62
+
63
+ # LayerNorm
64
+ def key_mapping_ln(key):
65
+ key = re.sub(r"^transformer.norm.", r"transformer.ln_f.", key)
66
+ key = re.sub(
67
+ r"^transformer.layers.(\d+).attention_norm.",
68
+ r"transformer.layers.\1.norm1.",
69
+ key,
70
+ )
71
+ key = re.sub(r"^transformer.layers.(\d+).ffn_norm.", r"transformer.layers.\1.norm2.", key)
72
+ return key
73
+
74
+ state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items())
75
+
76
+ # MLP
77
+ for l in range(config.n_layer):
78
+ w1 = state_dict.pop(f"transformer.layers.{l}.feed_forward.w1.weight")
79
+ w3 = state_dict.pop(f"transformer.layers.{l}.feed_forward.w3.weight")
80
+ # Our ordering is different
81
+ state_dict[f"transformer.layers.{l}.mlp.fc1.weight"] = torch.cat([w3, w1], dim=0)
82
+
83
+ def key_mapping_mlp(key):
84
+ return re.sub(
85
+ r"^transformer.layers.(\d+).feed_forward.w2.",
86
+ r"transformer.layers.\1.mlp.fc2.",
87
+ key,
88
+ )
89
+
90
+ state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items())
91
+
92
+ # Attention
93
+ for l in range(config.n_layer):
94
+ Wq = state_dict.pop(f"transformer.layers.{l}.attention.wq.weight")
95
+ Wk = state_dict.pop(f"transformer.layers.{l}.attention.wk.weight")
96
+ Wv = state_dict.pop(f"transformer.layers.{l}.attention.wv.weight")
97
+ state_dict[f"transformer.layers.{l}.mixer.Wqkv.weight"] = torch.cat([Wq, Wk, Wv], dim=0)
98
+ # We don't store these
99
+ state_dict.pop(f"transformer.layers.{l}.attention.inner_attention.rope.freqs", None)
100
+
101
+ def key_mapping_attn(key):
102
+ return re.sub(
103
+ r"^transformer.layers.(\d+).attention.wo.",
104
+ r"transformer.layers.\1.mixer.out_proj.",
105
+ key,
106
+ )
107
+
108
+ state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())
109
+
110
+ state_dict.pop("transformer.rope.freqs", None)
111
+
112
+ return state_dict
113
+
114
+
115
+ def remap_state_dict_hf_llama(
116
+ state_dict: Dict[str, torch.Tensor], config: GPT2Config
117
+ ) -> Dict[str, torch.Tensor]:
118
+ """Convert the state_dict in Hugging Face format to standard GPT format.
119
+
120
+ This function modifies state_dict in place.
121
+ """
122
+
123
+ # Embedding
124
+ def key_mapping_emb(key):
125
+ return re.sub(r"^model.embed_tokens.", "transformer.embeddings.word_embeddings.", key)
126
+
127
+ state_dict = OrderedDict((key_mapping_emb(k), v) for k, v in state_dict.items())
128
+ word_embeddings = state_dict.pop("transformer.embeddings.word_embeddings.weight")
129
+ # It's possible that vocab_size is padded to be a multiple of 8, for example.
130
+ pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
131
+ vocab_size = (
132
+ math.ceil(word_embeddings.shape[0] / pad_vocab_size_multiple) * pad_vocab_size_multiple
133
+ )
134
+ state_dict["transformer.embeddings.word_embeddings.weight"] = F.pad(
135
+ word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0])
136
+ )
137
+
138
+ # LM head
139
+ if getattr(config, "tie_word_embeddings"):
140
+ state_dict["lm_head.weight"] = state_dict["transformer.embeddings.word_embeddings.weight"]
141
+ else:
142
+ output_embeddings = state_dict.pop("lm_head.weight")
143
+ # Need to recompute vocab_size since LLaMa shards the word embeddings and output embeddings
144
+ # differently.
145
+ vocab_size = (
146
+ math.ceil(output_embeddings.shape[0] / pad_vocab_size_multiple)
147
+ * pad_vocab_size_multiple
148
+ )
149
+ # It's possible that vocab_size is padded to be a multiple of 8, for example.
150
+ state_dict["lm_head.weight"] = F.pad(
151
+ output_embeddings, (0, 0, 0, vocab_size - output_embeddings.shape[0])
152
+ )
153
+
154
+ # MLP
155
+ for l in range(config.n_layer):
156
+ # Fusing weights this way based on difference in the following:
157
+ # https://github.com/huggingface/transformers/blob/b42010bb1d3cbf262d27e0a328661885be46dfdb/src/transformers/models/llama/modeling_llama.py#L220
158
+ # https://github.com/Dao-AILab/flash-attention/blob/c60851a8253257eb970e06a022c82517a8033e8c/flash_attn/modules/mlp.py#L115
159
+ w1 = state_dict.pop(f"model.layers.{l}.mlp.gate_proj.weight")
160
+ w3 = state_dict.pop(f"model.layers.{l}.mlp.up_proj.weight")
161
+ state_dict[f"transformer.layers.{l}.mlp.fc1.weight"] = torch.cat([w3, w1], dim=0)
162
+
163
+ def key_mapping_mlp(key):
164
+ return re.sub(
165
+ r"^model.layers.(\d+).mlp.down_proj.",
166
+ r"transformer.layers.\1.mlp.fc2.",
167
+ key,
168
+ )
169
+
170
+ state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items())
171
+
172
+ # LayerNorm
173
+ def key_mapping_ln(key):
174
+ key = re.sub(r"^model.norm.", r"transformer.ln_f.", key)
175
+ key = re.sub(
176
+ r"^model.layers.(\d+).input_layernorm.",
177
+ r"transformer.layers.\1.norm1.",
178
+ key,
179
+ )
180
+ key = re.sub(
181
+ r"^model.layers.(\d+).post_attention_layernorm.",
182
+ r"transformer.layers.\1.norm2.",
183
+ key,
184
+ )
185
+ return key
186
+
187
+ state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items())
188
+
189
+ def inv_permute(w):
190
+ # Inverse of permute implemented in:
191
+ # https://github.com/huggingface/transformers/blob/b42010bb1d3cbf262d27e0a328661885be46dfdb/src/transformers/models/llama/convert_llama_weights_to_hf.py#L114
192
+ return rearrange(
193
+ w, "(h two d) n -> (h d two) n", d=config.n_embd // config.n_head // 2, two=2
194
+ )
195
+
196
+ # Attention
197
+ for l in range(config.n_layer):
198
+ Wq = state_dict.pop(f"model.layers.{l}.self_attn.q_proj.weight")
199
+ Wk = state_dict.pop(f"model.layers.{l}.self_attn.k_proj.weight")
200
+ Wv = state_dict.pop(f"model.layers.{l}.self_attn.v_proj.weight")
201
+
202
+ state_dict[f"transformer.layers.{l}.mixer.Wqkv.weight"] = torch.cat(
203
+ [inv_permute(Wq), inv_permute(Wk), Wv], dim=0
204
+ )
205
+ # We don't store these
206
+ state_dict.pop(f"model.layers.{l}.self_attn.rotary_emb.inv_freq", None)
207
+
208
+ def key_mapping_attn(key):
209
+ return re.sub(
210
+ r"^model.layers.(\d+).self_attn.o_proj.",
211
+ r"transformer.layers.\1.mixer.out_proj.",
212
+ key,
213
+ )
214
+
215
+ state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())
216
+ return state_dict
217
+
218
+
219
+ def inv_remap_state_dict_hf_llama(
220
+ state_dict: Dict[str, torch.Tensor], config: GPT2Config
221
+ ) -> Dict[str, torch.Tensor]:
222
+ """Convert the state_dict in standard GPT format to Hugging Face format.
223
+
224
+ This function is meant to be the inverse of remap_state_dict_hf_llama, up to a
225
+ multiplier pad in the embedding and lm_head. That is if the original embedding
226
+ isn't a multiple of pad_vocab_size_multiple, then
227
+ inv_remap_state_dict_hf_llama(remap_state_dict_hf_llama(state_dict)) != state_dict.
228
+
229
+ This function modifies state_dict in place.
230
+ """
231
+
232
+ # Embedding
233
+ def key_mapping_emb(key):
234
+ return re.sub(r"^transformer.embeddings.word_embeddings.", "model.embed_tokens.", key)
235
+
236
+ state_dict = OrderedDict((key_mapping_emb(k), v) for k, v in state_dict.items())
237
+ word_embeddings = state_dict.pop("model.embed_tokens.weight")
238
+ pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
239
+ vocab_size = (
240
+ math.ceil(word_embeddings.shape[0] / pad_vocab_size_multiple) * pad_vocab_size_multiple
241
+ )
242
+ state_dict["model.embed_tokens.weight"] = F.pad(
243
+ word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0])
244
+ )
245
+
246
+ # LM head
247
+ if getattr(config, "tie_word_embeddings"):
248
+ state_dict["lm_head.weight"] = state_dict["model.embed_tokens.weight"]
249
+ else:
250
+ output_embeddings = state_dict.pop("lm_head.weight")
251
+ vocab_size = (
252
+ math.ceil(output_embeddings.shape[0] / pad_vocab_size_multiple)
253
+ * pad_vocab_size_multiple
254
+ )
255
+ state_dict["lm_head.weight"] = F.pad(
256
+ output_embeddings, (0, 0, 0, vocab_size - output_embeddings.shape[0])
257
+ )
258
+
259
+ # MLP
260
+ for l in range(config.n_layer):
261
+ w3, w1 = torch.chunk(
262
+ state_dict.pop(f"transformer.layers.{l}.mlp.fc1.weight"), chunks=2, dim=0
263
+ )
264
+ state_dict[f"model.layers.{l}.mlp.gate_proj.weight"] = w1
265
+ state_dict[f"model.layers.{l}.mlp.up_proj.weight"] = w3
266
+
267
+ def key_mapping_mlp(key):
268
+ return re.sub(
269
+ r"^transformer.layers.(\d+).mlp.fc2.",
270
+ r"model.layers.\1.mlp.down_proj.",
271
+ key,
272
+ )
273
+
274
+ state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items())
275
+
276
+ # LayerNorm
277
+ def key_mapping_ln(key):
278
+ key = re.sub(r"^transformer.ln_f.", r"model.norm.", key)
279
+ key = re.sub(
280
+ r"^transformer.layers.(\d+).norm1.",
281
+ r"model.layers.\1.input_layernorm.",
282
+ key,
283
+ )
284
+ key = re.sub(
285
+ r"^transformer.layers.(\d+).norm2.",
286
+ r"model.layers.\1.post_attention_layernorm.",
287
+ key,
288
+ )
289
+ return key
290
+
291
+ state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items())
292
+
293
+ def permute(w):
294
+ return rearrange(
295
+ w, "(h d two) n -> (h two d) n", d=config.n_embd // config.n_head // 2, two=2
296
+ )
297
+
298
+ n_head = config.n_head
299
+ n_head_kv = getattr(config, "n_head_kv", n_head)
300
+
301
+ embed_dim = config.hidden_size
302
+ head_dim = embed_dim // n_head
303
+
304
+ q_dim = n_head * head_dim
305
+ k_dim = v_dim = n_head_kv * head_dim
306
+
307
+ # Attention
308
+ for l in range(config.n_layer):
309
+ Wqkv = state_dict.pop(f"transformer.layers.{l}.mixer.Wqkv.weight")
310
+ Wq = Wqkv[:q_dim]
311
+ Wk = Wqkv[q_dim : q_dim + k_dim]
312
+ Wv = Wqkv[q_dim + k_dim : q_dim + k_dim + v_dim]
313
+ state_dict[f"model.layers.{l}.self_attn.q_proj.weight"] = permute(Wq)
314
+ state_dict[f"model.layers.{l}.self_attn.k_proj.weight"] = permute(Wk)
315
+ state_dict[f"model.layers.{l}.self_attn.v_proj.weight"] = Wv
316
+ state_dict.pop(f"transformer.layers.{l}.attention.inner_attention.rope.freqs", None)
317
+
318
+ def key_mapping_attn(key):
319
+ return re.sub(
320
+ r"^transformer.layers.(\d+).mixer.out_proj.",
321
+ r"model.layers.\1.self_attn.o_proj.",
322
+ key,
323
+ )
324
+
325
+ state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())
326
+ return state_dict
327
+
328
+
329
def config_from_meta_checkpoint(
    checkpoint_path: Union[str, os.PathLike], model_name: str
) -> LlamaConfig:
    """Build a LlamaConfig from a Meta-format checkpoint's ``params.json``.

    The MLP hidden size and the vocab size are not stored directly in
    ``params.json``; both are reconstructed here.
    """
    ckpt_dir = Path(checkpoint_path) / model_name
    with open(ckpt_dir / "params.json") as f:
        params = json.load(f)
    config = LlamaConfig(
        hidden_size=params["dim"],
        intermediate_size=None,
        num_attention_heads=params["n_heads"],
        num_hidden_layers=params["n_layers"],
        rms_norm_eps=params["norm_eps"],
        num_key_value_heads=params.get("n_kv_heads", None),
    )
    multiple_of = params.get("multiple_of", 1)
    ffn_dim_multiplier = params.get("ffn_dim_multiplier", None)

    # Reconstruct the MLP hidden dim exactly as the Meta reference code does:
    # start from 4*d, scale by 2/3, apply the optional custom multiplier, then
    # round up to a multiple of `multiple_of`.
    # https://github.com/facebookresearch/llama/blob/1a240688810f8036049e8da36b073f63d2ac552c/llama/model.py#L195-L224
    ffn_hidden = int(2 * (4 * config.hidden_size) / 3)
    if ffn_dim_multiplier is not None:
        ffn_hidden = int(ffn_dim_multiplier * ffn_hidden)
    ffn_hidden = multiple_of * ((ffn_hidden + multiple_of - 1) // multiple_of)
    config.intermediate_size = ffn_hidden

    if "rope_theta" in params:
        config.rotary_emb_base = params["rope_theta"]

    # Vocab size is sadly not specified in `params.json`; default to 32000 but
    # prefer the tokenizer's own count when available (some CodeLlama variants
    # use 32016).
    config.vocab_size = 32000
    tokenizer = ckpt_dir / "tokenizer.model"
    if tokenizer.is_file():
        config.vocab_size = SentencePieceProcessor(str(tokenizer)).vocab_size()
    return config
366
+
367
+
368
def config_from_hf_checkpoint(
    checkpoint_path: Union[str, os.PathLike], model_name: str
) -> LlamaConfig:
    """Load a LlamaConfig from a converted HF checkpoint directory ``{model_name}-hf``."""
    config_file = Path(checkpoint_path) / f"{model_name}-hf" / "config.json"
    return LlamaConfig.from_pretrained(config_file)
372
+
373
+
374
def config_from_checkpoint(
    checkpoint_path: Union[str, os.PathLike], model_name: str, checkpoint_format="meta"
) -> LlamaConfig:
    """Dispatch config loading based on checkpoint format ("meta" or HF)."""
    loader = (
        config_from_meta_checkpoint if checkpoint_format == "meta" else config_from_hf_checkpoint
    )
    return loader(checkpoint_path, model_name)
381
+
382
+
383
def state_dicts_from_checkpoint(
    checkpoint_path: Union[str, os.PathLike], model_name: str
) -> List[dict]:
    """Load every ``consolidated.*.pth`` shard of a Meta checkpoint, in shard order.

    Sorting the paths is required: shards are ordered, and loading them out of
    order would assign the wrong weights.
    """
    shard_paths = sorted((Path(checkpoint_path) / model_name).glob("consolidated.*.pth"))
    return [torch.load(shard, map_location="cpu") for shard in shard_paths]
391
+
392
+
393
def llama_config_to_gpt2_config(llama_config: LlamaConfig) -> GPT2Config:
    """Translate a LlamaConfig into the extended GPT2Config this repo's GPT model consumes."""
    gpt2_kwargs = {
        "vocab_size": llama_config.vocab_size,
        "n_positions": 0,  # No absolute position embedding
        "n_embd": llama_config.hidden_size,
        "n_layer": llama_config.num_hidden_layers,
        "n_head": llama_config.num_attention_heads,
        "n_inner": llama_config.intermediate_size,
        "activation_function": "swiglu",  # Hardcode since HF calls it 'silu'
        # Llama doesn't have dropout, idk if it's because they only release the inference code
        "resid_pdrop": 0.0,
        "embd_pdrop": 0.0,
        "attn_pdrop": 0.0,
        "layer_norm_epsilon": llama_config.rms_norm_eps,
        "initializer_range": llama_config.initializer_range,
        "bos_token_id": llama_config.bos_token_id,
        "eos_token_id": llama_config.eos_token_id,
        # The remaining arguments are new (not in the original GPT2Config)
        "pad_token_id": llama_config.pad_token_id,  # Idk if this does anything
        "rms_norm": True,
        "rotary_emb_fraction": 1.0,
        "rotary_emb_interleaved": True,
        "tie_word_embeddings": False,
        "qkv_proj_bias": False,
        "out_proj_bias": False,
        "mlp_fc1_bias": False,
        "mlp_fc2_bias": False,
        "rotary_emb_base": getattr(llama_config, "rotary_emb_base", 10000.0),
        "n_head_kv": llama_config.num_key_value_heads,
    }
    return GPT2Config(**gpt2_kwargs)
.venv/lib/python3.11/site-packages/xformers/_flash_attn/models/opt.py ADDED
@@ -0,0 +1,116 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2023, Tri Dao.
2
+
3
+ import math
4
+ import re
5
+ from collections import OrderedDict
6
+
7
+ import torch
8
+ import torch.nn.functional as F
9
+ from transformers import GPT2Config, OPTConfig
10
+
11
+
12
def remap_state_dict_hf_opt(state_dict, config):
    """Remap a HuggingFace OPT state dict to this repo's GPT layout.

    Renames decoder/embedding/layernorm/MLP keys, drops the two rows OPT
    reserves at the start of the position embedding, pads the word embedding
    to a multiple of ``config.pad_vocab_size_multiple``, ties ``lm_head`` to
    the (padded) word embedding, and fuses the per-layer q/k/v projections
    into a single ``Wqkv``. Keys are rewritten in passes; the pass order
    matters because later regexes match keys produced by earlier passes.
    """
    def key_mapping_model(key):
        key = re.sub(r"^model.decoder.", "transformer.", key)
        # The OPT-350m model uses '^decoder' instead of '^model.decoder'
        key = re.sub(r"^decoder.", "transformer.", key)
        return key

    state_dict = OrderedDict((key_mapping_model(k), v) for k, v in state_dict.items())

    # Word embedding and position embedding
    def key_mapping_emb(key):
        key = re.sub(r"^transformer.embed_tokens.", "transformer.embeddings.word_embeddings.", key)
        # The OPT-350m model has project_in and project_out
        key = re.sub(r"^transformer.project_in.", "transformer.embeddings.project_in.", key)
        key = re.sub(r"^transformer.project_out.", "project_out.", key)
        key = re.sub(
            r"^transformer.embed_positions.", "transformer.embeddings.position_embeddings.", key
        )
        return key

    state_dict = OrderedDict((key_mapping_emb(k), v) for k, v in state_dict.items())
    # OPT uses the first 2 indices of pos_emb for padding tokens
    pos_embeddings = state_dict.pop("transformer.embeddings.position_embeddings.weight")
    state_dict["transformer.embeddings.position_embeddings.weight"] = pos_embeddings[2:]
    word_embeddings = state_dict.pop("transformer.embeddings.word_embeddings.weight")
    # It's possible that vocab_size is padded to be a multiple of 8, for example.
    pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
    vocab_size = math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple
    state_dict["transformer.embeddings.word_embeddings.weight"] = F.pad(
        word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0])
    )
    # lm_head shares the (padded) word-embedding weight
    state_dict["lm_head.weight"] = state_dict["transformer.embeddings.word_embeddings.weight"]

    # LayerNorm
    def key_mapping_ln(key):
        key = re.sub(r"^transformer.final_layer_norm.", r"transformer.ln_f.", key)
        # The OPT-175B checkpoint calls this 'decoder.layer_norm' instead of 'decoder.final_layer_norm'
        key = re.sub(r"^transformer.layer_norm.", r"transformer.ln_f.", key)
        key = re.sub(
            r"^transformer.layers.(\d+).self_attn_layer_norm.", r"transformer.layers.\1.norm1.", key
        )
        key = re.sub(
            r"^transformer.layers.(\d+).final_layer_norm.", r"transformer.layers.\1.norm2.", key
        )
        return key

    state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items())

    # MLP
    def key_mapping_mlp(key):
        return re.sub(
            r"^transformer.layers.(\d+).fc(1|2).", r"transformer.layers.\1.mlp.fc\2.", key
        )

    state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items())

    # Attention: fuse separate q/k/v projections into a single Wqkv (q, k, v
    # stacked along the output dimension, in that order).
    for l in range(config.n_layer):
        Wq = state_dict.pop(f"transformer.layers.{l}.self_attn.q_proj.weight")
        Wk = state_dict.pop(f"transformer.layers.{l}.self_attn.k_proj.weight")
        Wv = state_dict.pop(f"transformer.layers.{l}.self_attn.v_proj.weight")
        bq = state_dict.pop(f"transformer.layers.{l}.self_attn.q_proj.bias")
        bk = state_dict.pop(f"transformer.layers.{l}.self_attn.k_proj.bias")
        bv = state_dict.pop(f"transformer.layers.{l}.self_attn.v_proj.bias")
        state_dict[f"transformer.layers.{l}.mixer.Wqkv.weight"] = torch.cat([Wq, Wk, Wv], dim=0)
        state_dict[f"transformer.layers.{l}.mixer.Wqkv.bias"] = torch.cat([bq, bk, bv], dim=0)

    def key_mapping_attn(key):
        return re.sub(
            r"^transformer.layers.(\d+).self_attn.out_proj.",
            r"transformer.layers.\1.mixer.out_proj.",
            key,
        )

    state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())

    return state_dict
88
+
89
+
90
def opt_config_to_gpt2_config(opt_config: OPTConfig) -> GPT2Config:
    """Translate an OPTConfig into the extended GPT2Config this repo's GPT model consumes."""
    assert opt_config.layerdrop == 0.0
    assert opt_config.layer_norm_elementwise_affine
    # Only record a projection dim when it actually differs from the hidden size
    # (used by OPT-350m's project_in / project_out).
    proj_dim = opt_config.word_embed_proj_dim
    word_embed_proj_dim = None if proj_dim == opt_config.hidden_size else proj_dim
    return GPT2Config(
        vocab_size=opt_config.vocab_size,
        n_positions=opt_config.max_position_embeddings,
        n_embd=opt_config.hidden_size,
        n_layer=opt_config.num_hidden_layers,
        n_head=opt_config.num_attention_heads,
        n_inner=opt_config.ffn_dim,
        activation_function=opt_config.activation_function,
        resid_pdrop=opt_config.dropout,
        # HF's implementation of OPT doesn't seem to have embedding dropout
        embd_pdrop=opt_config.dropout,
        attn_pdrop=opt_config.attention_dropout,
        initializer_range=opt_config.init_std,
        bos_token_id=opt_config.bos_token_id,
        eos_token_id=opt_config.eos_token_id,
        # These are new arguments not in the original GPT2Config
        prenorm=opt_config.do_layer_norm_before,
        word_embed_proj_dim=word_embed_proj_dim,
    )
.venv/lib/python3.11/site-packages/xformers/_flash_attn/models/vit.py ADDED
@@ -0,0 +1,373 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2022, Tri Dao.
2
+ # Inspired by / adapted from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
3
+ import math
4
+ import re
5
+ from collections import OrderedDict
6
+ from copy import deepcopy
7
+ from functools import partial
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+ from einops import rearrange
13
+ from timm.models.helpers import named_apply
14
+ from torch.nn.init import trunc_normal_
15
+ from torchvision.ops import StochasticDepth
16
+
17
+ from flash_attn.layers.patch_embed import PatchEmbed
18
+ from flash_attn.modules.block import Block
19
+ from flash_attn.modules.mha import MHA
20
+ from flash_attn.modules.mlp import FusedMLP, Mlp
21
+
22
+ try:
23
+ from flash_attn.ops.triton.layer_norm import layer_norm_fn
24
+ except ImportError:
25
+ layer_norm_fn = None
26
+
27
+
28
def create_mixer_cls(
    num_heads, qkv_bias, attn_drop, use_flash_attn, fused_bias_fc, cross_attn=False
):
    """Return a partial MHA constructor with the attention options bound."""
    return partial(
        MHA,
        num_heads=num_heads,
        cross_attn=cross_attn,
        qkv_proj_bias=qkv_bias,
        dropout=attn_drop,
        fused_bias_fc=fused_bias_fc,
        use_flash_attn=use_flash_attn,
    )
41
+
42
+
43
def create_mlp_cls(embed_dim, mlp_ratio, act_layer, fused_mlp):
    """Return a partial MLP constructor (fused or plain) with the hidden size bound."""
    hidden_features = int(embed_dim * mlp_ratio)
    if fused_mlp:
        return partial(FusedMLP, hidden_features=hidden_features)
    # Plain MLP takes an activation module instance; FusedMLP does not.
    return partial(Mlp, hidden_features=hidden_features, activation=act_layer())
50
+
51
+
52
def create_block(
    embed_dim,
    num_heads,
    mlp_ratio,
    qkv_bias,
    drop_rate,
    attn_drop_rate,
    drop_path1,
    drop_path2,
    norm_layer,
    act_layer,
    use_flash_attn,
    fused_bias_fc,
    fused_mlp,
    fused_dropout_add_ln,
    layer_idx=None,
    n_layer=None,
    last_layer_subset=False,
):
    """Assemble one prenorm transformer Block (attention mixer + MLP) for the ViT.

    When ``last_layer_subset`` is set and this is the final layer, the mixer is
    built as cross-attention so only a subset of tokens (the CLS token) is used
    as the query.
    """
    use_cross_attn = last_layer_subset and layer_idx == n_layer - 1
    mixer_factory = create_mixer_cls(
        num_heads,
        qkv_bias,
        attn_drop_rate,
        use_flash_attn,
        fused_bias_fc,
        cross_attn=use_cross_attn,
    )
    mlp_factory = create_mlp_cls(embed_dim, mlp_ratio, act_layer, fused_mlp)
    # TD [2022-10-15]: Force residual in fp32 in case of DeepSpeed
    return Block(
        embed_dim,
        mixer_factory,
        mlp_factory,
        norm_cls=norm_layer,
        prenorm=True,
        resid_dropout1=drop_rate,
        resid_dropout2=drop_rate,
        drop_path1=drop_path1,
        drop_path2=drop_path2,
        fused_dropout_add_ln=fused_dropout_add_ln,
        residual_in_fp32=True,
    )
95
+
96
+
97
class VisionTransformer(nn.Module):
    """Vision Transformer
    A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale`
        - https://arxiv.org/abs/2010.11929
    """

    def __init__(
        self,
        img_size=224,
        patch_size=16,
        in_chans=3,
        num_classes=1000,
        global_pool="token",
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4.0,
        qkv_bias=True,
        init_values=None,
        class_token=True,
        no_embed_class=False,
        pre_norm=False,
        fc_norm=None,
        drop_rate=0.0,
        attn_drop_rate=0.0,
        drop_path_rate=0.0,
        weight_init="",
        embed_layer=PatchEmbed,
        norm_layer=None,
        act_layer=None,
        use_flash_attn=False,
        fused_bias_fc=False,
        fused_mlp=False,
        fused_dropout_add_ln=False,
    ):
        """
        Args:
            img_size (int, tuple): input image size
            patch_size (int, tuple): patch size
            in_chans (int): number of input channels
            num_classes (int): number of classes for classification head
            global_pool (str): type of global pooling for final sequence (default: 'token')
            embed_dim (int): embedding dimension
            depth (int): depth of transformer
            num_heads (int): number of attention heads
            mlp_ratio (int): ratio of mlp hidden dim to embedding dim
            qkv_bias (bool): enable bias for qkv if True
            init_values: (float): layer-scale init values
            class_token (bool): use class token
            fc_norm (Optional[bool]): pre-fc norm after pool, set if global_pool == 'avg' if None (default: None)
            drop_rate (float): dropout rate
            attn_drop_rate (float): attention dropout rate
            drop_path_rate (float): stochastic depth rate
            weight_init (str): weight init scheme
            embed_layer (nn.Module): patch embedding layer
            norm_layer: (nn.Module): normalization layer
            act_layer: (nn.Module): MLP activation layer
            use_flash_attn (bool): use FlashAttention in the attention mixer
            fused_bias_fc (bool): use the fused bias+GEMM kernels
            fused_mlp (bool): use the fused MLP kernels
            fused_dropout_add_ln (bool): fuse dropout + residual add + layernorm (requires Triton)
        """
        super().__init__()
        # Only a subset of the timm feature matrix is supported; assert the rest away.
        assert global_pool == "token", "Only support pooling with CLS token"
        assert class_token
        assert init_values is None, "LayerScale is not supported yet"
        assert weight_init == ""
        assert fc_norm is None
        # pre_norm seems redundant, as there's a LayerNorm right at the start of each block, idk
        assert not pre_norm
        use_fc_norm = global_pool == "avg" if fc_norm is None else fc_norm
        norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
        act_layer = act_layer or nn.GELU

        self.num_classes = num_classes
        self.global_pool = global_pool
        self.num_features = (
            self.embed_dim
        ) = embed_dim  # num_features for consistency with other models
        self.num_prefix_tokens = 1 if class_token else 0
        self.no_embed_class = no_embed_class

        # Only this repo's PatchEmbed knows about fused_bias_fc; a custom
        # embed_layer gets only the standard kwargs.
        patch_embed_extra_kwargs = (
            {"fused_bias_fc": fused_bias_fc} if embed_layer is PatchEmbed else {}
        )
        self.patch_embed = embed_layer(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=in_chans,
            embed_dim=embed_dim,
            bias=not pre_norm,  # disable bias if pre-norm is used (e.g. CLIP)
            **patch_embed_extra_kwargs,
        )
        num_patches = self.patch_embed.num_patches

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if class_token else None
        embed_len = num_patches if no_embed_class else num_patches + self.num_prefix_tokens
        self.pos_embed = nn.Parameter(torch.randn(1, embed_len, embed_dim) * 0.02)

        dpr = [
            x.item() for x in torch.linspace(0, drop_path_rate, depth)
        ]  # stochastic depth decay rule

        # We change the order of dropout, residual and layer norm:
        # Instead of LN -> Attn / MLP -> Dropout -> Add, we do:
        # Dropout -> Add -> LN -> Attn / MLP, returning both the residual branch (output of Add) and
        # the main branch (output of MLP). The model definition is unchanged, but the mapping of the
        # nn.Dropout probabilities are changed.
        # This is for performance reason: we can fuse dropout + add + layer_norm.
        self.blocks = nn.ModuleList(
            [
                create_block(
                    embed_dim,
                    num_heads,
                    mlp_ratio,
                    qkv_bias,
                    drop_rate,
                    attn_drop_rate,
                    drop_path1=dpr[i - 1] if i > 0 else 0.0,
                    drop_path2=dpr[i],
                    norm_layer=norm_layer,
                    act_layer=act_layer,
                    use_flash_attn=use_flash_attn,
                    fused_bias_fc=fused_bias_fc,
                    fused_mlp=fused_mlp,
                    fused_dropout_add_ln=fused_dropout_add_ln,
                    layer_idx=i,
                    n_layer=depth,
                    last_layer_subset=(global_pool == "token"),
                )
                for i in range(depth)
            ]
        )

        # Final dropout / drop-path / norm applied after the last block
        # (mirrors the per-block reordering described above).
        self.dropout = nn.Dropout(p=drop_rate)
        self.drop_path = StochasticDepth(p=dpr[-1], mode="row")
        self.norm = norm_layer(embed_dim)

        self.fused_dropout_add_ln = fused_dropout_add_ln
        if self.fused_dropout_add_ln and layer_norm_fn is None:
            raise ImportError("Triton is not installed")

        # Classifier Head
        self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()

        self.init_weights(weight_init)

    def init_weights(self, mode=""):
        """Initialize pos_embed / cls_token and apply timm's Linear init to submodules."""
        assert mode == ""
        trunc_normal_(self.pos_embed, std=0.02)
        if self.cls_token is not None:
            nn.init.normal_(self.cls_token, std=1e-6)
        named_apply(init_weights_vit_timm, self)

    def _init_weights(self, m):
        # this fn left here for compat with downstream users
        init_weights_vit_timm(m)

    @torch.jit.ignore
    def no_weight_decay(self):
        # Parameters excluded from weight decay by downstream optimizers.
        return {"pos_embed", "cls_token"}

    def _pos_embed(self, x):
        """Add position embedding and (optionally) prepend the CLS token."""
        if self.no_embed_class:
            # deit-3, updated JAX (big vision)
            # position embedding does not overlap with class token, add then concat
            x = x + self.pos_embed
            if self.cls_token is not None:
                x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
        else:
            # original timm, JAX, and deit vit impl
            # pos_embed has entry for class token, concat then add
            if self.cls_token is not None:
                x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
            x = x + self.pos_embed
        return x

    def forward_features(self, x, all_tokens=True):
        """
        If all_tokens==False and self.global_pool == 'token', we only return the features for the
        cls token.
        """
        x = self.patch_embed(x)
        hidden_states = self._pos_embed(x)
        residual = None
        if self.global_pool != "token" or all_tokens:
            for block in self.blocks:
                hidden_states, residual = block(hidden_states, residual)
        else:
            for block in self.blocks[:-1]:
                hidden_states, residual = block(hidden_states, residual)
            # For the last layer, we only want the 1st token of the output. So we do cross-attention
            # where the query is the 1st token and the key/value is the whole sequence.
            hidden_states, residual = self.blocks[-1](
                hidden_states, residual, mixer_subset=slice(0, 1)
            )
        if not self.fused_dropout_add_ln:
            residual = self.drop_path(self.dropout(hidden_states)) + residual
            hidden_states = self.norm(residual.to(dtype=self.norm.weight.dtype))
        else:
            # Fused path: dropout + residual-add + layernorm in one Triton kernel.
            if self.drop_path.p == 0 or not self.training:
                rowscale = None
            else:
                # Per-row stochastic-depth scaling passed into the fused kernel.
                rowscale = self.drop_path(
                    torch.ones(
                        hidden_states.shape[:-1],
                        device=hidden_states.device,
                        dtype=hidden_states.dtype,
                    )
                )
            # Set prenorm=False here since we don't need the residual
            hidden_states = layer_norm_fn(
                hidden_states,
                self.norm.weight,
                self.norm.bias,
                residual=residual,
                eps=self.norm.eps,
                dropout_p=self.dropout.p if self.training else 0.0,
                rowscale=rowscale,
                prenorm=False,
            )
        return hidden_states

    def forward_head(self, x, pre_logits: bool = False):
        """Pool the sequence (CLS token or mean) and apply the classifier head."""
        if self.global_pool:
            x = x[:, self.num_prefix_tokens :].mean(dim=1) if self.global_pool == "avg" else x[:, 0]
        return x if pre_logits else self.head(x)

    def forward(self, x):
        x = self.forward_features(x, all_tokens=False)
        x = self.forward_head(x)
        return x

    def load_state_dict(self, state_dict, strict=True):
        """Load a timm-style checkpoint, remapping keys to this implementation's layout."""
        patch_embed_weight = state_dict["patch_embed.proj.weight"]
        if patch_embed_weight.dim() == 4:
            # convert from Conv2d to Linear
            state_dict["patch_embed.proj.weight"] = rearrange(
                patch_embed_weight, "o c h w -> o (c h w)"
            )

        def key_mapping_attn(key):
            key = re.sub(r"^blocks.(\d+).attn.qkv.", r"blocks.\1.mixer.Wqkv.", key)
            key = re.sub(r"^blocks.(\d+).attn.proj.", r"blocks.\1.mixer.out_proj.", key)
            return key

        state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())
        n_layer = len(self.blocks)
        # Convert from Wqkv to Wq and Wkv for cross attention (last layer)
        if (
            self.blocks[-1].mixer.cross_attn
            and f"blocks.{n_layer - 1}.mixer.Wqkv.weight" in state_dict
        ):
            Wqkv = state_dict.pop(f"blocks.{n_layer - 1}.mixer.Wqkv.weight")
            bqkv = state_dict.pop(f"blocks.{n_layer - 1}.mixer.Wqkv.bias")
            state_dict[f"blocks.{n_layer - 1}.mixer.Wq.weight"] = Wqkv[: self.embed_dim]
            state_dict[f"blocks.{n_layer - 1}.mixer.Wkv.weight"] = Wqkv[self.embed_dim :]
            state_dict[f"blocks.{n_layer - 1}.mixer.Wq.bias"] = bqkv[: self.embed_dim]
            state_dict[f"blocks.{n_layer - 1}.mixer.Wkv.bias"] = bqkv[self.embed_dim :]
        return super().load_state_dict(state_dict, strict=strict)
354
+
355
+
356
def init_weights_vit_timm(module: nn.Module, name: str = ""):
    """ViT weight initialization, original timm impl (for reproducibility)."""
    if not isinstance(module, nn.Linear):
        # Non-Linear modules delegate to their own initializer when they have one.
        if hasattr(module, "init_weights"):
            module.init_weights()
        return
    trunc_normal_(module.weight, std=0.02)
    if module.bias is not None:
        nn.init.zeros_(module.bias)
364
+
365
+
366
def vit_base_patch16_224(pretrained=False, **kwargs):
    """ViT-Base (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
    ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer.
    """
    # Pretrained weights are not wired up for this variant.
    assert not pretrained
    return VisionTransformer(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
.venv/lib/python3.11/site-packages/xformers/_flash_attn/ops/__init__.py ADDED
File without changes
.venv/lib/python3.11/site-packages/xformers/_flash_attn/ops/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (197 Bytes). View file
 
.venv/lib/python3.11/site-packages/xformers/_flash_attn/ops/__pycache__/activations.cpython-311.pyc ADDED
Binary file (6.86 kB). View file
 
.venv/lib/python3.11/site-packages/xformers/_flash_attn/ops/__pycache__/fused_dense.cpython-311.pyc ADDED
Binary file (30.4 kB). View file
 
.venv/lib/python3.11/site-packages/xformers/_flash_attn/ops/__pycache__/layer_norm.cpython-311.pyc ADDED
Binary file (22.6 kB). View file
 
.venv/lib/python3.11/site-packages/xformers/_flash_attn/ops/__pycache__/rms_norm.cpython-311.pyc ADDED
Binary file (5.19 kB). View file
 
.venv/lib/python3.11/site-packages/xformers/_flash_attn/ops/activations.py ADDED
@@ -0,0 +1,135 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copied from https://github.com/mlcommons/training_results_v1.1/blob/main/NVIDIA/benchmarks/bert/implementations/pytorch/model/layers/activations.py
2
+ import math
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+
8
+ # 1/sqrt(2*pi)-> 0.3989423
9
+ # 1/sqrt(2) -> 0.70710678
10
+ # sqrt(2/pi) -> 0.79788456
11
+
12
+ # this function is tanh approximation of gelu
13
+ # actual gelu is:
14
+ # x * 0.5 * (1.0 + torch.erf(x * 0.70710678))
15
@torch.jit.script
def bias_gelu(y, bias):
    """Tanh approximation of GELU applied to ``y + bias``.

    Exact GELU is x * 0.5 * (1 + erf(x * 0.70710678)); this uses the cheaper
    tanh form. Result is cast back to ``y``'s dtype.
    """
    x = bias + y
    return (x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))).to(dtype=y.dtype)


@torch.jit.script
def bias_gelu_back(g, y, bias):
    """Backward of the tanh-approximate GELU above.

    Assume that y has shape (B, D) and bias has shape (D). Returns the pair
    (grad wrt y, grad wrt bias); the bias gradient is the input gradient
    summed over the batch dimension.
    """
    x = bias + y
    tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))
    # sqrt(2/pi) * 3 * 0.044715 -> 0.1070322243
    ff = 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (
        1 + tanh_out
    )
    grad_y = ff * g
    return grad_y.to(dtype=y.dtype), grad_y.sum(dim=(0), dtype=bias.dtype)


class GeLUFunction(torch.autograd.Function):
    """Autograd wrapper around the fused bias + tanh-GELU functions."""

    @staticmethod
    def forward(ctx, input, bias):
        ctx.save_for_backward(input, bias)
        return bias_gelu(input, bias)

    @staticmethod
    def backward(ctx, grad_output):
        input, bias = ctx.saved_tensors
        # bias_gelu_back returns a (grad_input, grad_bias) pair; unpack it so
        # autograd receives exactly one gradient per forward input. (The
        # previous ``return tmp, tmp`` handed the whole tuple back twice,
        # which crashes in the backward pass.)
        grad_input, grad_bias = bias_gelu_back(grad_output, input, bias)
        return grad_input, grad_bias


bias_gelu_impl = GeLUFunction.apply
52
+
53
+ # this function is tanh approximation of gelu
54
+ # actual gelu is:
55
+ # x * 0.5 * (1.0 + torch.erf(x * 0.70710678))
56
@torch.jit.script
def gelu_fwd(x):
    """Tanh approximation of GELU (exact form uses erf); cast back to x's dtype."""
    inner = 0.79788456 * x * (1 + 0.044715 * x * x)
    return (x * 0.5 * (1.0 + torch.tanh(inner))).to(dtype=x.dtype)


@torch.jit.script
def gelu_bwd(g, x):
    """Gradient of the tanh-approximate GELU, multiplied by the upstream grad g."""
    tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))
    # sqrt(2/pi) * 3 * 0.044715 -> 0.1070322243
    sech_term = (1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)
    ff = 0.5 * x * sech_term + 0.5 * (1 + tanh_out)
    return (ff * g).to(dtype=x.dtype)


class FastGeLUFunction(torch.autograd.Function):
    """Autograd wrapper for the tanh-approximate GELU (no bias)."""

    @staticmethod
    def forward(ctx, input):
        ctx.save_for_backward(input)
        return gelu_fwd(input)

    @staticmethod
    def backward(ctx, grad_output):
        (input,) = ctx.saved_tensors
        return gelu_bwd(grad_output, input)


fast_gelu_impl = FastGeLUFunction.apply
89
+
90
+
91
@torch.jit.script
def relu_bwd(g, x):
    """Backward of ReLU: pass the upstream grad where x >= 0, zero elsewhere."""
    mask = x >= 0
    return torch.where(mask, g, 0.0).to(dtype=x.dtype)
94
+
95
+
96
@torch.jit.script
def sqrelu_fwd(x):
    """Squared ReLU: relu(x)**2, cast back to x's dtype."""
    relu_x = F.relu(x)
    squared = relu_x * relu_x
    return squared.to(dtype=x.dtype)


@torch.jit.script
def sqrelu_bwd(g, x):
    """Backward of squared ReLU: d/dx relu(x)^2 = 2 * relu(x), times upstream grad g."""
    grad = 2.0 * g * F.relu(x)
    return grad.to(dtype=x.dtype)
105
+
106
+
107
# CUDA jiterator kernels for SwiGLU: out = silu(x) * y, i.e. x * sigmoid(x) * y.
# The C++ code strings are compiled elementwise at first call (CUDA tensors only).
swiglu_fwd_codestring = """
template <typename T> T swiglu_fwd(T x, T y) {
    return float(x) * float(y) / (1.0f + ::exp(-float(x)));
}
"""
swiglu_bwd_codestring = """
template <typename T> T swiglu_bwd(T x, T y, T g, T& dx, T& dy) {
    float x_sigmoid = 1.0f / (1.0f + ::exp(-float(x)));
    dx = x_sigmoid * (1 + float(x) * (1.0f - x_sigmoid)) * float(g) * float(y);
    dy = float(x) * x_sigmoid * float(g);
}
"""
swiglu_fwd = torch.cuda.jiterator._create_jit_fn(swiglu_fwd_codestring)
swiglu_bwd = torch.cuda.jiterator._create_multi_output_jit_fn(swiglu_bwd_codestring, num_outputs=2)


class SwiGLUFunction(torch.autograd.Function):
    """Autograd wrapper around the jiterator SwiGLU kernels.

    NOTE(review): jiterator kernels run on CUDA tensors only — presumably all
    callers pass GPU tensors; verify before using on CPU.
    """

    @staticmethod
    def forward(ctx, x, y):
        ctx.save_for_backward(x, y)
        return swiglu_fwd(x, y)

    @staticmethod
    def backward(ctx, dout):
        x, y = ctx.saved_tensors
        # swiglu_bwd produces (dx, dy) in one fused kernel launch.
        return swiglu_bwd(x, y, dout)


swiglu = SwiGLUFunction.apply
.venv/lib/python3.11/site-packages/xformers/_flash_attn/ops/fused_dense.py ADDED
@@ -0,0 +1,688 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2023, Tri Dao.
2
+ # Inspired by https://github.com/NVIDIA/apex/blob/master/apex/fused_dense/fused_dense.py
3
+ # We make it work with pytorch amp and with bfloat16.
4
+ # The TensorParallel linear modules are inspired by https://github.com/NVIDIA/apex/blob/master/apex/transformer/tensor_parallel/layers.py
5
+ from functools import partial
6
+ from typing import Optional
7
+
8
+ # import fused_dense_cuda # from apex
9
+ import fused_dense_lib as fused_dense_cuda
10
+ import torch
11
+ import torch.nn as nn
12
+ import torch.nn.functional as F
13
+ from torch import Tensor
14
+ from torch.cuda.amp import custom_bwd, custom_fwd
15
+ from torch.distributed import ProcessGroup
16
+
17
+ from flash_attn.ops.activations import gelu_bwd, relu_bwd, sqrelu_bwd, sqrelu_fwd
18
+ from flash_attn.utils.distributed import (
19
+ all_gather_raw,
20
+ all_reduce,
21
+ all_reduce_raw,
22
+ reduce_scatter,
23
+ reduce_scatter_raw,
24
+ )
25
+
26
+
27
class FusedDenseFunc(torch.autograd.Function):
    # Linear layer (y = x @ W^T + b) whose backward uses the fused cuBLASLt
    # wgrad kernel (fused_dense_cuda.linear_bias_wgrad) to compute dW and db
    # in one call, and overlaps tensor-parallel collectives with compute via
    # async handles.

    @staticmethod
    @custom_fwd
    def forward(
        ctx, x, weight, bias, return_residual=False, process_group=None, sequence_parallel=True
    ):
        """
        If process_group is not None and sequence_parallel=True, we're doing Tensor Parallel
        with sequence parallelism: we do an all_gather_raw of x before doing the matmul.
        """
        ctx.compute_weight_gradient = weight.requires_grad
        ctx.return_residual = return_residual
        ctx.process_group = process_group
        ctx.sequence_parallel = sequence_parallel

        if torch.is_autocast_enabled():
            x = x.to(dtype=torch.get_autocast_gpu_dtype())
        x = x.contiguous()
        if process_group is not None and sequence_parallel:
            # We want to kick off the all_gather early, before weight dtype conversion
            total_x, handle_x = all_gather_raw(x, process_group, async_op=True)
        else:
            total_x = x

        if torch.is_autocast_enabled():
            weight = weight.to(dtype=torch.get_autocast_gpu_dtype())
            bias = bias.to(dtype=torch.get_autocast_gpu_dtype()) if bias is not None else None
        weight = weight.contiguous()
        if process_group is not None and sequence_parallel:
            # Matmul needs the fully gathered input; wait for the async gather.
            handle_x.wait()
        batch_shape, n = total_x.shape[:-1], total_x.shape[-1]
        batch_dim = batch_shape.numel()
        # https://github.com/pytorch/pytorch/blob/5b51849b48a7dbccd297286cc0110def4706f9e7/aten/src/ATen/native/cuda/Blas.cpp#L174
        if min(batch_dim, n, *weight.shape) > 65535 * 32:
            raise RuntimeError("fused_dense only supports matrix dims <= 2M")
        output = F.linear(total_x, weight, bias)
        if ctx.compute_weight_gradient:
            # Save the un-gathered x (not total_x): backward re-gathers it,
            # trading communication for activation memory.
            ctx.save_for_backward(x, weight)
        else:
            ctx.save_for_backward(weight)
        return output if not return_residual else (output, x)

    @staticmethod
    @custom_bwd
    def backward(ctx, grad_output, *args):
        # When return_residual=True, *args carries the gradient flowing into
        # the residual copy of x; it is accumulated into grad_input via addmm.
        grad_output = grad_output.contiguous()
        if ctx.return_residual:
            (grad_input,) = args
            grad_input = grad_input.contiguous()
        process_group = ctx.process_group
        sequence_parallel = ctx.sequence_parallel
        if ctx.compute_weight_gradient:
            x, weight = ctx.saved_tensors
            if process_group is not None and sequence_parallel:
                # Re-gather x asynchronously; only awaited right before wgrad.
                total_x, handle_x = all_gather_raw(x, process_group, async_op=True)
            else:
                total_x = x
        else:
            (weight,) = ctx.saved_tensors
            total_x = None
        batch_shape = grad_output.shape[:-1]
        batch_dim = batch_shape.numel()
        grad_output = grad_output.reshape(batch_dim, grad_output.shape[-1])
        if ctx.needs_input_grad[0]:
            if not ctx.return_residual:
                grad_input = F.linear(grad_output, weight.t())
            else:
                # Fuse the residual gradient add with dX = dY @ W.
                grad_input = torch.addmm(
                    grad_input.reshape(batch_dim, grad_input.shape[-1]), grad_output, weight
                )
            grad_input = grad_input.reshape(*batch_shape, grad_input.shape[-1])
            if process_group is not None:
                # Async reduce of dX overlaps with the wgrad computation below.
                reduce_fn = reduce_scatter_raw if sequence_parallel else all_reduce_raw
                grad_input, handle_grad_input = reduce_fn(grad_input, process_group, async_op=True)
        else:
            grad_input = None
        if ctx.needs_input_grad[1]:
            assert ctx.compute_weight_gradient
            if process_group is not None and sequence_parallel:
                handle_x.wait()
            # Fused kernel computes dW and (optionally) db in a single call.
            grad_weight, grad_bias = fused_dense_cuda.linear_bias_wgrad(
                total_x.reshape(batch_dim, total_x.shape[-1]), grad_output, ctx.needs_input_grad[2]
            )
        else:
            grad_weight = None
            grad_bias = grad_output if ctx.needs_input_grad[2] else None
        if process_group is not None and ctx.needs_input_grad[0]:
            handle_grad_input.wait()
        # One gradient per forward arg; the last three (return_residual,
        # process_group, sequence_parallel) are non-tensors.
        return grad_input, grad_weight, grad_bias, None, None, None
116
+
117
+
118
def fused_dense_func(
    x: Tensor,
    weight: Tensor,
    bias: Optional[Tensor] = None,
    return_residual: bool = False,
    process_group: Optional[ProcessGroup] = None,
    sequence_parallel: bool = True,
):
    """Dispatch to FusedDenseFunc when the fused CUDA path applies, else F.linear.

    The fused path requires every tensor on CUDA and an fp16/bf16 dtype
    (or fp32 under autocast). The plain fallback cannot do tensor-parallel
    communication, so process_group must be None there.
    """
    dtype_eligible = x.dtype in [torch.float16, torch.bfloat16] or (
        x.dtype == torch.float32 and torch.is_autocast_enabled()
    )
    all_on_gpu = x.is_cuda and weight.is_cuda and (bias is None or bias.is_cuda)
    if all_on_gpu and dtype_eligible:
        return FusedDenseFunc.apply(
            x, weight, bias, return_residual, process_group, sequence_parallel
        )
    # Fallback: ordinary linear, no tensor parallelism available.
    assert process_group is None
    out = F.linear(x, weight, bias)
    if return_residual:
        return out, x
    return out
137
+
138
+
139
class FusedDense(nn.Linear):
    """Drop-in ``nn.Linear`` that routes its forward through ``fused_dense_func``."""

    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool = True,
        return_residual: bool = False,
        device=None,
        dtype=None,
    ) -> None:
        super().__init__(in_features, out_features, bias=bias, device=device, dtype=dtype)
        # When True, forward returns (output, x) so the caller can fuse the
        # residual-connection backward with the linear backward.
        self.return_residual = return_residual

    def forward(self, x, process_group=None):
        """Apply the linear layer.

        If process_group is not None, we're doing Tensor Parallel with
        sequence parallelism: x is all-gathered before the matmul.
        """
        return fused_dense_func(
            x,
            self.weight,
            self.bias,
            return_residual=self.return_residual,
            process_group=process_group,
        )
164
+
165
+
166
class ColumnParallelLinear(nn.Linear):
    """Linear layer whose output dimension is sharded across a process group.

    ``out_features`` is split (counted in units of ``multiple_of``) as evenly
    as possible over the ranks; the first few ranks absorb the remainder.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        process_group: ProcessGroup,
        bias: bool = True,
        sequence_parallel=True,
        multiple_of=1,
        device=None,
        dtype=None,
    ) -> None:
        world_size = torch.distributed.get_world_size(process_group)
        rank = torch.distributed.get_rank(process_group)
        if out_features % multiple_of:
            raise ValueError(f"out_features ({out_features}) must be a multiple of {multiple_of}")
        # Split the units across ranks; the first `remainder` ranks each get
        # one extra unit so the shards differ in size by at most one unit.
        quotient, remainder = divmod(out_features // multiple_of, world_size)
        local_multiple = quotient + int(rank < remainder)
        super().__init__(
            in_features, local_multiple * multiple_of, bias=bias, device=device, dtype=dtype
        )
        self.process_group = process_group
        self.sequence_parallel = sequence_parallel

    def forward(self, x):
        """Local matmul on this rank's output shard.

        If sequence_parallel, x is all-gathered first; otherwise the input is
        assumed to be already gathered.
        """
        return fused_dense_func(
            x,
            self.weight,
            self.bias,
            process_group=self.process_group,
            sequence_parallel=self.sequence_parallel,
        )
204
+
205
+
206
class RowParallelLinear(nn.Linear):
    """Linear layer whose input dimension is sharded across a process group.

    Only rank 0 holds a bias, so it is added exactly once after the partial
    outputs are reduced across ranks.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        process_group: ProcessGroup,
        bias: bool = True,
        sequence_parallel=True,
        multiple_of=1,
        device=None,
        dtype=None,
    ) -> None:
        world_size = torch.distributed.get_world_size(process_group)
        rank = torch.distributed.get_rank(process_group)
        if in_features % multiple_of:
            raise ValueError(f"in_features ({in_features}) must be a multiple of {multiple_of}")
        # Split the units across ranks; the first `remainder` ranks each get
        # one extra unit so the shards differ in size by at most one unit.
        quotient, remainder = divmod(in_features // multiple_of, world_size)
        local_multiple = quotient + int(rank < remainder)
        # A bias on every rank would be summed world_size times by the
        # reduction, so only rank 0 keeps one.
        super().__init__(
            local_multiple * multiple_of,
            out_features,
            bias=bias and rank == 0,
            device=device,
            dtype=dtype,
        )
        self.process_group = process_group
        self.sequence_parallel = sequence_parallel

    def forward(self, x):
        """Matmul on this rank's input shard, then reduce the partial result.

        With sequence parallelism the reduction is a reduce_scatter;
        otherwise an all_reduce.
        """
        partial_out = fused_dense_func(x, self.weight, self.bias)
        if self.sequence_parallel:
            return reduce_scatter(partial_out, self.process_group)
        return all_reduce(partial_out, self.process_group)
247
+
248
+
249
class FusedMLPFunc(torch.autograd.Function):
    # Fused two-layer MLP (fc1 -> activation -> fc2) with optional activation
    # checkpointing (checkpoint_lvl), a cuBLASLt heuristic selector
    # (heuristic == -1 means unfused GEMM + activation kernels), and
    # tensor-parallel collectives overlapped with compute via async handles.

    @staticmethod
    @custom_fwd
    def forward(
        ctx,
        x,
        weight1,
        bias1,
        weight2,
        bias2,
        activation="gelu_approx",
        save_pre_act=True,
        return_residual=False,
        checkpoint_lvl=0,
        heuristic=0,
        process_group=None,
        sequence_parallel=True,
    ):
        """
        If process_group is not None and sequence_parallel=True, we're doing Tensor Parallel
        with sequence parallelism: we do an all_gather of x before doing the matmul.
        If sequence_parallel=False, then the input is already gathered.

        checkpoint_lvl:
        0: no recomputation in the bwd
        1: recompute gelu_out / relu_out in the bwd
        2: recompute pre_act and gelu_out / relu_out in the bwd
        """
        assert -1 <= heuristic <= 4
        assert activation in ["gelu_approx", "relu", "sqrelu"]
        if activation == "sqrelu":
            # sqrelu only has the unfused (separate-kernel) path.
            assert heuristic == -1
        if not save_pre_act:
            # Nothing saved => everything must be recomputed in backward.
            checkpoint_lvl = 2
        assert checkpoint_lvl in [0, 1, 2]
        ctx.return_residual = return_residual
        ctx.process_group = process_group
        ctx.sequence_parallel = sequence_parallel
        ctx.checkpoint_lvl = checkpoint_lvl
        ctx.activation = activation
        ctx.heuristic = heuristic

        if torch.is_autocast_enabled():
            x = x.to(dtype=torch.get_autocast_gpu_dtype())
        x = x.contiguous()
        if process_group is not None and sequence_parallel:
            # We want to kick off the all_gather early, before weight dtype conversion
            total_x, handle_x = all_gather_raw(x, process_group, async_op=True)
        else:
            total_x = x

        if torch.is_autocast_enabled():
            dtype = torch.get_autocast_gpu_dtype()
            weight1, weight2 = [a.to(dtype=dtype) for a in [weight1, weight2]]
            bias1 = bias1.to(dtype=dtype) if bias1 is not None else None
            bias2 = bias2.to(dtype=dtype) if bias2 is not None else None
        weight1 = weight1.contiguous()
        bias1 = bias1.contiguous() if bias1 is not None else None
        weight2 = weight2.contiguous()
        bias2 = bias2.contiguous() if bias2 is not None else None
        if process_group is not None and sequence_parallel:
            handle_x.wait()
        batch_shape, n = total_x.shape[:-1], total_x.shape[-1]
        batch_dim = batch_shape.numel()
        # https://github.com/pytorch/pytorch/blob/5b51849b48a7dbccd297286cc0110def4706f9e7/aten/src/ATen/native/cuda/Blas.cpp#L174
        if min(batch_dim, n, *weight1.shape, *weight2.shape) > 65535 * 32:
            raise RuntimeError("fused_dense only supports matrix dims <= 2M")
        if heuristic == -1:
            # Unfused path: plain GEMM then a (jit-fused) activation kernel.
            pre_act = F.linear(total_x, weight1, bias1)
            activation_fn = (
                partial(F.gelu, approximate="tanh")
                if activation == "gelu_approx"
                else (sqrelu_fwd if activation == "sqrelu" else F.relu)
            )
            with torch.jit.fuser("fuser2"):
                output1 = activation_fn(pre_act)
            # This is before adding bias1
            # pre_act = F.linear(total_x.reshape(batch_dim, n), weight1)
            # with torch.jit.fuser('fuser2'):
            #     output1 = bias_gelu(pre_act, bias1)
        else:
            # Fused GEMM + bias + activation epilogue (cuBLASLt).
            is_gelu = activation == "gelu_approx"
            output1, *rest = fused_dense_cuda.linear_act_forward(
                total_x.reshape(batch_dim, n), weight1, bias1, is_gelu, save_pre_act, heuristic
            )
            if save_pre_act:
                pre_act = rest[0]
        output2 = F.linear(output1, weight2, bias2)
        if checkpoint_lvl == 0 or (checkpoint_lvl == 1 and activation == "relu"):
            # For RELU the pre_act is very small (just a bit-mask) so we just save it
            ctx.save_for_backward(x, weight1, weight2, pre_act, output1)
        elif checkpoint_lvl == 1:
            ctx.save_for_backward(x, weight1, weight2, pre_act)
        elif checkpoint_lvl == 2:
            ctx.save_for_backward(x, weight1, weight2, bias1)
        output2 = output2.reshape(*batch_shape, output2.shape[-1])
        return output2 if not return_residual else (output2, x)

    @staticmethod
    @custom_bwd
    def backward(ctx, grad_output, *args):
        # Recomputes whatever the chosen checkpoint_lvl did not save, then
        # computes gradients for x, weight1/bias1, weight2/bias2.
        grad_output = grad_output.contiguous()
        checkpoint_lvl = ctx.checkpoint_lvl
        activation = ctx.activation
        activation_fn = (
            partial(F.gelu, approximate="tanh")
            if activation == "gelu_approx"
            else (sqrelu_fwd if activation == "sqrelu" else F.relu)
        )
        if ctx.return_residual:
            # Gradient flowing into the residual copy of x.
            (grad_input,) = args
            grad_input = grad_input.contiguous()
        process_group = ctx.process_group
        sequence_parallel = ctx.sequence_parallel
        x, weight1, weight2, *rest = ctx.saved_tensors
        if process_group is None or not sequence_parallel:
            total_x = x
        batch_shape = grad_output.shape[:-1]
        batch_dim = batch_shape.numel()
        if checkpoint_lvl in [0, 1]:
            if process_group is not None and sequence_parallel:
                # Async re-gather; awaited just before the wgrad for weight1.
                total_x, handle_x = all_gather_raw(x, process_group, async_op=True)
            if checkpoint_lvl == 0 or (checkpoint_lvl == 1 and activation == "relu"):
                pre_act, output1 = rest
            elif checkpoint_lvl == 1:
                (pre_act,) = rest
                with torch.jit.fuser("fuser2"):
                    output1 = activation_fn(pre_act)
        elif checkpoint_lvl == 2:
            (bias1,) = rest
            if process_group is not None and sequence_parallel:
                # Synchronous gather: the recomputation below needs total_x now.
                total_x, _ = all_gather_raw(x, process_group)
            if ctx.heuristic == -1:
                pre_act = F.linear(total_x, weight1, bias1)
                with torch.jit.fuser("fuser2"):
                    output1 = activation_fn(pre_act)
            else:
                output1, pre_act = fused_dense_cuda.linear_act_forward(
                    total_x.reshape(batch_dim, total_x.shape[-1]),
                    weight1,
                    bias1,
                    activation == "gelu_approx",
                    True,
                    ctx.heuristic,
                )

        grad_output = grad_output.reshape(batch_dim, grad_output.shape[-1])
        output1 = output1.reshape(batch_dim, output1.shape[-1])
        pre_act = pre_act.reshape(batch_dim, pre_act.shape[-1])
        if ctx.needs_input_grad[3]:
            # dW2 and (optionally) db2 in one fused call.
            grad_weight2, grad_bias2 = fused_dense_cuda.linear_bias_wgrad(
                output1, grad_output, ctx.needs_input_grad[4]
            )
        else:
            grad_weight2 = None
            grad_bias2 = grad_output if ctx.needs_input_grad[4] else None
        if ctx.heuristic == -1:
            # grad_pre_act = matmul_dgelu(grad_output, weight2, pre_act)
            grad_output1 = F.linear(grad_output, weight2.t())
            activation_grad_fn = (
                gelu_bwd
                if activation == "gelu_approx"
                else (sqrelu_bwd if activation == "sqrelu" else relu_bwd)
            )
            with torch.jit.fuser("fuser2"):
                grad_pre_act = activation_grad_fn(grad_output1, pre_act)
        else:
            # The cublasLt epilogue has to compute both gelu/relu grad and bias grad, we can't
            # just compute gelu/relu grad
            grad_pre_act, grad_bias1 = fused_dense_cuda.bias_act_linear_dgrad_bgrad(
                weight2, grad_output, pre_act, activation == "gelu_approx", ctx.heuristic
            )
            if not ctx.needs_input_grad[2]:
                grad_bias1 = None
        if ctx.needs_input_grad[0]:
            if not ctx.return_residual:
                grad_input = F.linear(grad_pre_act, weight1.t())
            else:
                # Fuse the residual gradient add with dX = d(pre_act) @ W1.
                grad_input = torch.addmm(
                    grad_input.reshape(batch_dim, grad_input.shape[-1]), grad_pre_act, weight1
                )
            grad_input = grad_input.reshape(*batch_shape, grad_input.shape[-1])
            if process_group is not None:
                # Async reduce of dX overlaps with weight1's wgrad below.
                reduce_fn = reduce_scatter_raw if sequence_parallel else all_reduce_raw
                grad_input, handle_grad_input = reduce_fn(grad_input, process_group, async_op=True)
        else:
            grad_input = None
        if ctx.heuristic == -1:
            if ctx.needs_input_grad[1]:
                if process_group is not None and sequence_parallel and checkpoint_lvl != 2:
                    # lvl 2 already gathered synchronously; otherwise await it.
                    handle_x.wait()
                grad_weight1, grad_bias1 = fused_dense_cuda.linear_bias_wgrad(
                    total_x.reshape(batch_dim, total_x.shape[-1]),
                    grad_pre_act,
                    ctx.needs_input_grad[2],
                )
            else:
                grad_weight1 = None
                grad_bias1 = grad_pre_act if ctx.needs_input_grad[2] else None
        else:
            if ctx.needs_input_grad[1]:
                if process_group is not None and sequence_parallel and checkpoint_lvl != 2:
                    handle_x.wait()
                # grad_bias1 was already produced by bias_act_linear_dgrad_bgrad.
                grad_weight1 = F.linear(
                    grad_pre_act.t(), total_x.reshape(batch_dim, total_x.shape[-1]).t()
                )
            else:
                grad_weight1 = None
        if process_group is not None and ctx.needs_input_grad[0]:
            handle_grad_input.wait()
        # One gradient slot per forward arg; the trailing Nones cover the
        # seven non-tensor arguments (activation ... sequence_parallel).
        return (
            grad_input,
            grad_weight1,
            grad_bias1,
            grad_weight2,
            grad_bias2,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
        )
473
+
474
+
475
def fused_mlp_func(
    x: Tensor,
    weight1: Tensor,
    weight2: Tensor,
    bias1: Optional[Tensor] = None,
    bias2: Optional[Tensor] = None,
    activation: str = "gelu_approx",
    save_pre_act: bool = True,
    return_residual: bool = False,
    checkpoint_lvl: int = 0,
    heuristic: int = 0,
    process_group: Optional[ProcessGroup] = None,
    sequence_parallel: bool = True,
):
    """Two-layer MLP (linear -> activation -> linear), fused on CUDA when eligible.

    Dispatches to FusedMLPFunc when all tensors are on CUDA, the dtype is
    fp16/bf16 (or fp32 under autocast), and the hidden dim satisfies the
    alignment the fused kernel requires when saving the pre-activation.
    Otherwise falls back to plain PyTorch ops; the fallback does not support
    tensor parallelism, so process_group must be None there.
    """
    assert activation in ["gelu_approx", "relu", "sqrelu"]
    dtype_eligible = x.dtype in [torch.float16, torch.bfloat16] or (
        x.dtype == torch.float32 and torch.is_autocast_enabled()
    )
    # If we save pre-activation, dimension must be divisible by 128 (relu) or 8 (gelu)
    dim_eligible = not save_pre_act or (x.shape[-1] % (128 if activation == "relu" else 8) == 0)
    if (
        x.is_cuda
        and weight1.is_cuda
        and weight2.is_cuda
        and (bias1 is None or bias1.is_cuda)
        and (bias2 is None or bias2.is_cuda)
        and dtype_eligible
        and dim_eligible
    ):
        return FusedMLPFunc.apply(
            x,
            weight1,
            bias1,
            weight2,
            bias2,
            activation,
            save_pre_act,
            return_residual,
            checkpoint_lvl,
            heuristic,
            process_group,
            sequence_parallel,
        )
    else:
        assert process_group is None
        pre_act = F.linear(x, weight1, bias1)
        # Bug fix: the fallback previously applied plain ReLU when
        # activation == "sqrelu", silently computing the wrong activation.
        # Use squared ReLU (relu(x)^2) to match the fused path.
        if activation == "gelu_approx":
            output1 = F.gelu(pre_act, approximate="tanh")
        elif activation == "sqrelu":
            output1 = F.relu(pre_act) ** 2
        else:
            output1 = F.relu(pre_act, inplace=True)
        output2 = F.linear(output1, weight2, bias2)
        return output2 if not return_residual else (output2, x)
529
+
530
+
531
class FusedMLP(nn.Module):
    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        bias1=True,
        bias2=True,
        activation="gelu_approx",
        return_residual=False,
        checkpoint_lvl=0,
        heuristic="auto",
        device=None,
        dtype=None,
    ):
        """Two-layer MLP that runs through ``fused_mlp_func``.

        If process_group is passed to forward, we're doing Tensor Parallel with
        sequence parallelism: x is all-gathered before matmul-activation-matmul
        and the output is reduce-scattered at the end.

        checkpoint_lvl (increasing lvl means slower but more memory saving):
            0: no recomputation in the bwd
            1: recompute gelu_out in the bwd
            2: recompute pre_act and gelu_out in the bwd
        heuristic:
            -1: don't fuse gemm + gelu (separate kernel)
            0..4: use this heuristic for the algo section in the fused gemm + gelu
            'auto': heuristic will be picked automatically:
                For CUDA >= 11.8, we set heuristic=0 for both fp16 and bf16 for best perf.
                For CUDA <= 11.7, we set heuristic=1 for fp16 and heuristic=-1 for bf16.
                For H100, we set heuristic=-1 for both fp16 and bf16 as the fused
                cuBlasLt implementation is slower than the unfused version.
        return_residual: whether to return the input x along with the output. This is for
            performance reason: for post-norm architecture, returning the input allows us
            to fuse the backward of nn.Linear with the residual connection.
        """
        assert checkpoint_lvl in [0, 1, 2]
        assert activation in ["gelu_approx", "relu", "sqrelu"]
        super().__init__()
        factory_kwargs = {"device": device, "dtype": dtype}
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features * 4
        self.activation = activation
        self.return_residual = return_residual
        self.checkpoint_lvl = checkpoint_lvl
        # sqrelu only supports the unfused path, so force heuristic=-1.
        self.heuristic = -1 if activation == "sqrelu" else heuristic
        self.fc1 = nn.Linear(in_features, hidden_features, bias=bias1, **factory_kwargs)
        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias2, **factory_kwargs)

    def _resolve_heuristic(self, dtype):
        # Pick the cuBlasLt algo heuristic when self.heuristic == "auto".
        if self.activation != "gelu_approx":
            return 0
        if torch.cuda.get_device_capability("cuda") == (9, 0):
            # On H100 the fused kernel is slower than the unfused path.
            return -1
        cuda_ver = tuple(map(int, torch.version.cuda.split(".")))
        if cuda_ver >= (11, 8):
            return 0
        return 1 if dtype == torch.float16 else -1

    def forward(self, x, process_group=None):
        dtype = torch.get_autocast_gpu_dtype() if torch.is_autocast_enabled() else x.dtype
        heuristic = (
            self._resolve_heuristic(dtype) if self.heuristic == "auto" else self.heuristic
        )
        out = fused_mlp_func(
            x,
            self.fc1.weight,
            self.fc2.weight,
            self.fc1.bias,
            self.fc2.bias,
            activation=self.activation,
            save_pre_act=self.training,
            return_residual=self.return_residual,
            checkpoint_lvl=self.checkpoint_lvl,
            heuristic=heuristic,
            process_group=process_group,
        )
        if self.return_residual:
            out, x = out
        if process_group is not None:
            out = reduce_scatter(out, process_group)
        return out if not self.return_residual else (out, x)
611
+
612
+
613
class ParallelFusedMLP(nn.Module):
    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        activation="gelu_approx",
        process_group: ProcessGroup = None,
        bias1=True,
        bias2=True,
        sequence_parallel=True,
        checkpoint_lvl=0,
        heuristic="auto",
        device=None,
        dtype=None,
    ):
        """Tensor-parallel two-layer MLP (column-parallel fc1, row-parallel fc2).

        process_group is required. We're doing Tensor Parallel with sequence
        parallelism: x is all-gathered before matmul-activation-matmul, and the
        output is reduce-scattered (or all-reduced) at the end.

        checkpoint_lvl (increasing lvl means slower but more memory saving):
            0: no recomputation in the bwd
            1: recompute gelu_out in the bwd
            2: recompute pre_act and gelu_out in the bwd
        heuristic:
            -1: don't fuse gemm + gelu (separate kernel)
            0..4: use this heuristic for the algo section in the fused gemm + gelu
            'auto': heuristic will be picked automatically:
                For CUDA >= 11.8, we set heuristic=0 for both fp16 and bf16 for best perf.
                For CUDA <= 11.7, we set heuristic=1 for fp16 and heuristic=-1 for bf16.
        """
        assert checkpoint_lvl in [0, 1, 2]
        assert activation in ["gelu_approx", "relu", "sqrelu"]
        assert process_group is not None
        super().__init__()
        factory_kwargs = {"device": device, "dtype": dtype}
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features * 4
        self.activation = activation
        self.process_group = process_group
        self.sequence_parallel = sequence_parallel
        self.checkpoint_lvl = checkpoint_lvl
        # sqrelu only supports the unfused path, so force heuristic=-1.
        self.heuristic = -1 if activation == "sqrelu" else heuristic
        self.fc1 = ColumnParallelLinear(
            in_features, hidden_features, process_group, bias=bias1, **factory_kwargs
        )
        self.fc2 = RowParallelLinear(
            hidden_features, out_features, process_group, bias=bias2, **factory_kwargs
        )

    def _resolve_heuristic(self, dtype):
        # Pick the cuBlasLt algo heuristic when self.heuristic == "auto".
        # Note: unlike FusedMLP, there is no H100 special case here.
        if self.activation != "gelu_approx":
            return 0
        cuda_ver = tuple(map(int, torch.version.cuda.split(".")))
        if cuda_ver >= (11, 8):
            return 0
        return 1 if dtype == torch.float16 else -1

    def forward(self, x):
        dtype = torch.get_autocast_gpu_dtype() if torch.is_autocast_enabled() else x.dtype
        heuristic = (
            self._resolve_heuristic(dtype) if self.heuristic == "auto" else self.heuristic
        )
        out = fused_mlp_func(
            x,
            self.fc1.weight,
            self.fc2.weight,
            self.fc1.bias,
            self.fc2.bias,
            activation=self.activation,
            save_pre_act=self.training,
            checkpoint_lvl=self.checkpoint_lvl,
            heuristic=heuristic,
            process_group=self.process_group,
            sequence_parallel=self.sequence_parallel,
        )
        if self.sequence_parallel:
            return reduce_scatter(out, self.process_group)
        return all_reduce(out, self.process_group)