mansaripo commited on
Commit
b0fd683
·
verified ·
1 Parent(s): 4d9bee1

Upload folder using huggingface_hub

Browse files
config.json ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "CloverLMForCausalLM"
4
+ ],
5
+ "attn_backend": "flash2",
6
+ "auto_map": {
7
+ "AutoConfig": "configuration_cloverlm.CloverLMConfig",
8
+ "AutoModelForCausalLM": "modeling_cloverlm.CloverLMForCausalLM",
9
+ "AutoTokenizer": [
10
+ "tokenization_cloverlm.CloverLMTokenizer",
11
+ null
12
+ ]
13
+ },
14
+ "d_head": 128,
15
+ "heads": 28,
16
+ "max_context": 1024,
17
+ "model_type": "cloverlm",
18
+ "num_blocks": 29,
19
+ "num_hidden_layers": 29,
20
+ "quartet_2_impl": "pseudoquant",
21
+ "ratio": 4,
22
+ "scale_type": "1/sqrt(d)",
23
+ "transformers_version": "5.3.0",
24
+ "vocab_size": 32000,
25
+ "weight_tying": true
26
+ }
configuration_cloverlm.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import PretrainedConfig
2
+
3
+
4
class CloverLMConfig(PretrainedConfig):
    """HF configuration for CloverLM.

    The field names mirror the research model's constructor arguments
    (heads, d_head, ratio, scale_type, ...); `num_hidden_layers` is kept
    as a duplicate of `num_blocks` so generic transformers utilities that
    read `num_hidden_layers` see the model depth.
    """

    model_type = "cloverlm"

    def __init__(
        self,
        vocab_size=32000,               # tokenizer vocabulary size
        num_blocks=4,                   # number of transformer blocks
        heads=6,                        # attention heads; model dim = heads * d_head
        d_head=128,                     # per-head dimension
        ratio=3,                        # GQA ratio (heads per KV group)
        scale_type="1/sqrt(d)",         # attention-logit scaling scheme
        max_context=1024,               # maximum sequence length
        quartet_2_impl="pseudoquant",   # quantization implementation tag — TODO confirm semantics
        weight_tying=True,              # tie input embedding with LM head
        attn_backend="pytorch",         # SDPA backend identifier
        **kwargs,
    ):
        self.num_blocks = num_blocks
        # Alias so depth is visible under the conventional HF attribute name.
        self.num_hidden_layers = num_blocks
        self.heads = heads
        self.d_head = d_head
        self.ratio = ratio
        self.scale_type = scale_type
        self.max_context = max_context
        self.quartet_2_impl = quartet_2_impl
        self.weight_tying = weight_tying
        self.attn_backend = attn_backend
        super().__init__(vocab_size=vocab_size, **kwargs)
exp_mlp.py ADDED
@@ -0,0 +1,168 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
def sphere_norm(X, dim=-1):
    """L2-normalize X along `dim`, projecting it onto the unit hypersphere."""
    projected = torch.nn.functional.normalize(X, dim=dim)
    return projected
7
+
8
class SphereNorm(torch.nn.Module):
    """Module form of sphere normalization: L2-normalize along `dim`."""

    def __init__(self, dim=-1):
        super().__init__()
        self.dim = dim

    def forward(self, X):
        # Inlined sphere_norm: project onto the unit hypersphere along self.dim.
        return torch.nn.functional.normalize(X, dim=self.dim)
18
+
19
+ def get_norm(enable, norm_type, d, bias):
20
+ if enable:
21
+ if norm_type=="layer":
22
+ norm = torch.nn.LayerNorm(d, bias=bias)
23
+ elif norm_type=="rms_learned":
24
+ norm = torch.nn.RMSNorm(d, elementwise_affine=True)
25
+ elif norm_type=="rms_const":
26
+ norm = torch.nn.RMSNorm(d, elementwise_affine=False)
27
+ elif norm_type=="sphere":
28
+ norm = SphereNorm(dim=-1)
29
+ else:
30
+ norm = None
31
+
32
+ return norm
33
+
34
class ReLU2(torch.nn.Module):
    """Squared ReLU activation: max(x, 0)**2."""

    def forward(self, x):
        return x.clamp_min(0) ** 2
39
+
40
class Abs(torch.nn.Module):
    """Absolute-value activation."""

    def forward(self, x):
        return torch.abs(x)
45
+
46
class GLU(torch.nn.Module):
    """Gated linear unit: y = act(W_gate x) * (W_proj x).

    `quartet` / `fake_quartet` select the quantized linear implementation;
    plain torch.nn.Linear is used when both are False.

    BUG FIX: the HF export stripped the `quartet2` import but left the code
    that references it (`pass  # quartet2 not available in HF mode`), so
    `quartet=True` (the default) crashed with NameError. We now import
    `quartet2` lazily and raise an explicit, actionable ImportError.
    """

    def __init__(self, d0, d1, bias=True, act=torch.nn.ReLU(), quartet=True, fake_quartet=False):
        super().__init__()

        self.d0 = d0
        self.d1 = d1
        self.bias = bias
        self.act = act
        self.quartet = quartet
        self.fake_quartet = fake_quartet

        if quartet:
            try:
                import quartet2
            except ImportError as e:
                raise ImportError(
                    "quartet=True requires the 'quartet2' package, which is not "
                    "bundled with this HF export; pass quartet=False (or fake_quartet=True)."
                ) from e
            self.gate = torch.nn.Sequential(quartet2.linear.Quartet_II_linear(d0, d1, bias), act)
            self.proj = quartet2.linear.Quartet_II_linear(d0, d1, bias)
        elif fake_quartet:
            from . import fake_quartet as fq
            self.gate = torch.nn.Sequential(fq.FakeQuartetLinear(d0, d1, bias), act)
            self.proj = fq.FakeQuartetLinear(d0, d1, bias)
        else:
            self.gate = torch.nn.Sequential(torch.nn.Linear(d0, d1, bias), act)
            self.proj = torch.nn.Linear(d0, d1, bias)

    def forward(self, x):
        # Elementwise product of the activated gate branch and the linear branch.
        y = self.gate(x) * self.proj(x)
        return y
76
+
77
class MLP2L(torch.nn.Module):
    """Two-layer MLP: l1 (linear+act or GLU) -> optional norm -> dropout -> l2.

    `quartet` / `fake_quartet` select the quantized linear implementation.

    BUG FIX: as in GLU, the HF export left references to the stripped
    `quartet2` module behind a `pass` placeholder, so `quartet=True` (the
    default) crashed with NameError; `quartet2` is now imported lazily with
    a clear error message.
    """

    def __init__(self, d0, d1, d2, bias=True, act=torch.nn.ReLU(), dropout=0, l1_type="linear", norm_type="rms_learned", norm=False, quartet=True, fake_quartet=False):
        super().__init__()

        self.d0 = d0
        self.d1 = d1
        self.d2 = d2
        self.bias = bias
        self.act = act
        self.dropout = dropout
        self.l1_type = l1_type
        self.norm_type = norm_type

        if l1_type=="linear":
            if quartet:
                try:
                    import quartet2
                except ImportError as e:
                    raise ImportError(
                        "quartet=True requires the 'quartet2' package, which is not "
                        "bundled with this HF export; pass quartet=False (or fake_quartet=True)."
                    ) from e
                self.l1 = torch.nn.Sequential(quartet2.linear.Quartet_II_linear(d0, d1, bias), act)
            elif fake_quartet:
                from . import fake_quartet as fq
                self.l1 = torch.nn.Sequential(fq.FakeQuartetLinear(d0, d1, bias), act)
            else:
                self.l1 = torch.nn.Sequential(torch.nn.Linear(d0, d1, bias), act)
        elif l1_type=="glu":
            self.l1 = GLU(d0, d1, bias, act, quartet, fake_quartet)

        # Optional normalization on the hidden activation (None when norm=False).
        self.norm = get_norm(norm, norm_type, d1, bias)

        if quartet:
            try:
                import quartet2
            except ImportError as e:
                raise ImportError(
                    "quartet=True requires the 'quartet2' package, which is not "
                    "bundled with this HF export; pass quartet=False (or fake_quartet=True)."
                ) from e
            self.l2 = quartet2.linear.Quartet_II_linear(d1, d2, bias)
        elif fake_quartet:
            from . import fake_quartet as fq
            self.l2 = fq.FakeQuartetLinear(d1, d2, bias)
        else:
            self.l2 = torch.nn.Linear(d1, d2, bias)

    def forward(self, x):
        a1 = self.l1(x)
        if self.norm: a1 = self.norm(a1)
        a1 = torch.nn.functional.dropout(a1, p=self.dropout, training=self.training)

        y = self.l2(a1)

        return y
121
+
122
class MLP3L(torch.nn.Module):
    """Three-layer MLP: (linear -> act -> dropout) x2 -> linear."""

    def __init__(self, d0, d1, d2, d3, bias=True, act=torch.nn.ReLU(), dropout=0):
        super().__init__()

        self.d0 = d0
        self.d1 = d1
        self.d2 = d2
        self.d3 = d3
        self.bias = bias
        self.act = act
        self.dropout = dropout

        # Layers created in the same order as declared dimensions.
        self.l1 = torch.nn.Linear(d0, d1, bias)
        self.l2 = torch.nn.Linear(d1, d2, bias)
        self.l3 = torch.nn.Linear(d2, d3, bias)

    def forward(self, x):
        hidden = x
        # Two hidden layers share the same act/dropout treatment.
        for layer in (self.l1, self.l2):
            hidden = self.act(layer(hidden))
            hidden = torch.nn.functional.dropout(hidden, p=self.dropout, training=self.training)
        return self.l3(hidden)
150
+
151
class MLP3L_image(torch.nn.Module):
    """Flatten a res x res image and classify it with a 3-layer MLP."""

    def __init__(self, res=28, d1=16, d2=16, dropout=0, classes=10):
        super().__init__()

        self.res = res
        self.d1 = d1
        self.d2 = d2
        self.dropout = dropout
        self.classes = classes

        # Input width is the flattened pixel count.
        self.mlp = MLP3L(res * res, d1, d2, classes, dropout=dropout)

    def forward(self, x):
        # Collapse the last three dims (channel, height, width) into features.
        flat = x.flatten(start_dim=-3, end_dim=-1)
        return self.mlp(flat)
exp_transformer.py ADDED
@@ -0,0 +1,696 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from . import exp_mlp as mlp
3
+ from math import sqrt
4
+ import math
5
+
6
# Valid values for the corresponding Transformer/MHSA constructor arguments:
# attention-logit scalings, positional-encoding schemes, SDPA backends, and
# normalization-layer kinds (consumed by exp_mlp.get_norm).
SCALE_TYPES = ["1/sqrt(d)", "1/d"]
POS_TYPES = ["learned", "sinusoidal", "rope", "alibi"]
BACKENDS = ["pytorch", "flash2", "flash3", "flash4", "flex", "cudnn"]
NORM_TYPES = ["layer", "rms_learned", "rms_const", "sphere"]
10
+
11
def get_causal(context):
    """Boolean lower-triangular (context, context) causal mask (True = attend)."""
    return torch.ones((context, context), dtype=torch.bool).tril()
17
+
18
def get_sinusoidal(context, d, base=1024):
    """Sinusoidal positional encoding of shape (context, d).

    Even feature indices carry sin(pos * w_j), odd indices cos(pos * w_j),
    with w_j = base**(-2j/d).
    """
    positions = torch.arange(0., context)
    frequencies = 1 / base ** (2 * torch.arange(0., d // 2) / d)

    # (context, d//2) phase matrix: outer product of positions and frequencies.
    phases = torch.outer(positions, frequencies)

    encoding = torch.empty((context, d))
    encoding[:, 0::2] = torch.sin(phases)
    encoding[:, 1::2] = torch.cos(phases)

    return encoding
36
+
37
def get_rope(context, d, *, device, base=1024):
    """RoPE cos/sin table of shape (2, context, d).

    rope[0] holds cosines, rope[1] sines; each (context, d//2) phase column
    is duplicated (repeat_interleave) so consecutive feature pairs share an
    angle, matching apply_rope's pairwise rotation.
    """
    positions = torch.arange(0., context, device=device, dtype=torch.float32)
    frequencies = 1 / base ** (2 * torch.arange(0., d // 2, device=device, dtype=torch.float32) / d)

    # (context, d//2) rotation angles.
    phases = torch.outer(positions, frequencies)

    # Expand each angle to its feature pair -> (context, d).
    cos = torch.cos(phases).repeat_interleave(repeats=2, dim=1)
    sin = torch.sin(phases).repeat_interleave(repeats=2, dim=1)

    # (2, context, d)
    return torch.stack((cos, sin))
60
+
61
# (batches*)context*d
def apply_rope(X, rope):
    """Apply a rotary positional embedding table (from get_rope) to X."""
    # Pairwise 90-degree rotation: (x0, x1) -> (-x1, x0).
    rotated = torch.empty_like(X)
    rotated[..., 0::2] = -X[..., 1::2]
    rotated[..., 1::2] = X[..., 0::2]

    cos, sin = rope[0], rope[1]  # each (context, d)

    out = X * cos + rotated * sin
    return out.to(X.dtype)
74
+
75
def get_m(heads, base=2, exp=8):
    """Per-head ALiBi slopes: base**(-exp*h/heads) for h = 1..heads."""
    head_indices = torch.arange(1, heads + 1)
    return base ** ((-exp / heads) * head_indices)
79
+
80
def get_alibi(heads, context):
    """ALiBi bias tensor of shape (heads, context, context): -|i-j| * slope."""
    # (1, context, 1) query positions and their (1, 1, context) transpose.
    i = torch.arange(0, context)[None, :, None]
    j = i.mT
    # (heads, 1, 1) per-head slopes broadcast over the position grid.
    slopes = get_m(heads)[:, None, None]

    return -(i - j).abs() * slopes
91
+
92
def get_swa(context, window):
    """Boolean sliding-window mask: True where |i - j| <= window."""
    offsets = torch.arange(0, context)
    return (offsets[:, None] - offsets[None, :]).abs() <= window
101
+
102
# (batches*)heads/groups*context*d_head
def sdpa_pytorch(Q, K, V, causal=None, alibi=None, swa=None, scale=None, return_A=False):
    """Reference scaled-dot-product attention with GQA, masks, and ALiBi.

    Returns Y, or (Y, raw logits, masked/scaled logits, softmax weights)
    when return_A is True.
    """
    if scale is None:
        scale = 1 / sqrt(Q.shape[-1])

    # Expand grouped K/V so each query head has its own K/V copy (GQA).
    # PyTorch only broadcasts when the operation is not defined otherwise;
    # matmul does not broadcast the head dimension here.
    group_ratio = Q.shape[-3] // K.shape[-3]
    K = K.repeat_interleave(repeats=group_ratio, dim=-3)
    V = V.repeat_interleave(repeats=group_ratio, dim=-3)

    # (batches*)heads*context*context raw attention logits.
    A__ = Q @ K.mT

    # A tensor-valued scale (1, heads, 1, 1) may broadcast in a batch dim;
    # reshape back to the logits' original shape.
    A_ = (scale * A__).reshape(A__.shape)

    if alibi is not None:
        A_ = A_ + alibi
    if causal is not None:
        A_.masked_fill_(~causal, -float("inf"))
    if swa is not None:
        A_.masked_fill_(~swa, -float("inf"))

    A = torch.softmax(A_, dim=-1)

    # (batches*)heads*context*d_head
    Y = A @ V

    if return_A:
        return Y, A__, A_, A
    return Y
140
+
141
# (batches*)heads/groups*context*d_head
def sdpa_flash(Q, K, V, causal=False, alibi=None, swa=None, scale=None, backend="flash2"):
    """Scaled-dot-product attention via FlashAttention 2/3/4.

    `alibi` is a per-head slope vector (flash2 only), `swa` a (left, right)
    window tuple, `scale` a float or a per-head tensor (folded into Q).

    BUG FIX: the ALiBi warning was a plain string literal, so "{backend}"
    was printed verbatim; it is now an f-string.
    """
    if (alibi is not None) and backend != "flash2":
        print(f"\x1b[93;3m[WARNING]: backend={backend} does not support ALiBi. Hence, we force backend=flash2.\x1b[0m")
        backend = "flash2"

    # FlashAttention only supports float scale
    if isinstance(scale, torch.Tensor):
        Q_shape = Q.shape
        # batches*heads*context*d_head (tensor scale may add a batch dim)
        Q = scale*Q
        # (batches*)heads*context*d_head
        Q = Q.reshape(Q_shape)

        scale = 1

    # FlashAttention2 only supports BF16 and FP16
    if Q.dtype in [torch.bfloat16, torch.float16]:
        dtype = Q.dtype
    else:
        dtype = torch.bfloat16

    heads = Q.shape[-3]
    groups = K.shape[-3]
    context = Q.shape[-2]
    d_head = Q.shape[-1]

    # CAUTION: FlashAttention expects batches*context*heads/groups*d_head
    Q = Q.movedim(-3,-2).reshape(-1,context,heads,d_head)
    K = K.movedim(-3,-2).reshape(-1,context,groups,d_head)
    V = V.movedim(-3,-2).reshape(-1,context,groups,d_head)

    if swa is None:
        # (-1, -1) disables the sliding window in flash_attn.
        swa = (-1,-1)

    if backend=="flash2":
        import flash_attn
        Y = flash_attn.flash_attn_func(Q.to(dtype), K.to(dtype), V.to(dtype), causal=causal, alibi_slopes=alibi, window_size=swa, softmax_scale=scale)
    elif backend=="flash3":
        import flash_attn_interface
        Y = flash_attn_interface.flash_attn_func(Q.to(dtype), K.to(dtype), V.to(dtype), causal=causal, window_size=swa, softmax_scale=scale)
    elif backend=="flash4":
        import flash_attn.cute
        # FlashAttention4 returns (out, lse)
        Y = flash_attn.cute.flash_attn_func(Q.to(dtype), K.to(dtype), V.to(dtype), causal=causal, window_size=swa, softmax_scale=scale)[0]

    Y = Y.to(Q.dtype)

    # Restore the shape to: (batches*)heads*context*d_head
    Y = Y.movedim(-3,-2).squeeze(0)

    return Y
193
+
194
# (batches*)heads/groups*context*d_head
def sdpa_flex():
    # TODO: FlexAttention backend not implemented; selecting backend="flex"
    # currently yields None from sdpa_wrapper.
    return None
197
+
198
# (batches*)heads/groups*context*d_head
def sdpa_cudnn():
    # TODO: cuDNN backend not implemented; selecting backend="cudnn"
    # currently yields None from sdpa_wrapper.
    return None
201
+
202
def sdpa_wrapper(Q, K, V, causal=None, alibi=None, swa=None, scale=None, return_A=False, backend="flash2"):
    """Dispatch SDPA to the selected backend (see BACKENDS).

    Only the pytorch backend honors return_A; an unrecognized backend
    falls through and returns None.
    """
    if backend == "pytorch":
        return sdpa_pytorch(Q, K, V, causal, alibi, swa, scale, return_A)
    if backend in {"flash2", "flash3", "flash4"}:
        return sdpa_flash(Q, K, V, causal, alibi, swa, scale, backend)
    if backend == "flex":
        return sdpa_flex()
    if backend == "cudnn":
        return sdpa_cudnn()
211
+
212
def test_sdpa():
    """Smoke-test the flash backends against the pytorch reference on CUDA.

    Requires a GPU plus the flash_attn / flash_attn_interface packages;
    prints a green check per feature combination. ALiBi cases only compare
    flash2 (flash3/4 do not support ALiBi).
    """
    batches = 32
    heads = 12
    context = 1024
    d_head = 64
    window = 256
    groups = 4
    dtype = torch.bfloat16

    print("\x1b[1mbfloat16\x1b[0m",end="")
    Q = torch.rand((batches, heads, context, d_head)).to("cuda:0", dtype)
    K = torch.rand((batches, heads, context, d_head)).to("cuda:0", dtype)
    V = torch.rand((batches, heads, context, d_head)).to("cuda:0", dtype)
    pytorch = sdpa_wrapper(Q, K, V, backend="pytorch")
    flash2 = sdpa_wrapper(Q, K, V, backend="flash2")
    torch.testing.assert_close(flash2, pytorch, check_dtype=False)
    flash3 = sdpa_wrapper(Q, K, V, backend="flash3")
    torch.testing.assert_close(flash3, pytorch, check_dtype=False)
    flash4 = sdpa_wrapper(Q, K, V, backend="flash4")
    torch.testing.assert_close(flash4, pytorch, check_dtype=False)
    print("\x1b[32m ✔\x1b[0m")

    print("\x1b[1mcausal\x1b[0m",end="")
    pytorch = sdpa_wrapper(Q, K, V, causal=get_causal(context).to("cuda:0"), backend="pytorch")
    flash2 = sdpa_wrapper(Q, K, V, causal=True, backend="flash2")
    torch.testing.assert_close(flash2, pytorch, check_dtype=False)
    flash3 = sdpa_wrapper(Q, K, V, causal=True, backend="flash3")
    torch.testing.assert_close(flash3, pytorch, check_dtype=False)
    flash4 = sdpa_wrapper(Q, K, V, causal=True, backend="flash4")
    torch.testing.assert_close(flash4, pytorch, check_dtype=False)
    print("\x1b[32m ✔\x1b[0m")

    print("\x1b[1malibi\x1b[0m",end="")
    pytorch = sdpa_wrapper(Q, K, V, alibi=get_alibi(heads,context).to("cuda:0",dtype), backend="pytorch")
    flash2 = sdpa_wrapper(Q, K, V, alibi=get_m(heads).to("cuda:0"), backend="flash2")
    torch.testing.assert_close(flash2, pytorch, check_dtype=False)
    # ALiBi not supported on FlashAttention3/4
    print("\x1b[32m ✔\x1b[0m")

    print("\x1b[1mswa\x1b[0m",end="")
    pytorch = sdpa_wrapper(Q, K, V, swa=get_swa(context,window).to("cuda:0"), backend="pytorch")
    flash2 = sdpa_wrapper(Q, K, V, swa=(window,window), backend="flash2")
    torch.testing.assert_close(flash2, pytorch, check_dtype=False)
    flash3 = sdpa_wrapper(Q, K, V, swa=(window,window), backend="flash3")
    torch.testing.assert_close(flash3, pytorch, check_dtype=False)
    flash4 = sdpa_wrapper(Q, K, V, swa=(window,window), backend="flash4")
    torch.testing.assert_close(flash4, pytorch, check_dtype=False)
    print("\x1b[32m ✔\x1b[0m")

    print("\x1b[1mcausal+alibi\x1b[0m",end="")
    pytorch = sdpa_wrapper(Q, K, V, causal=get_causal(context).to("cuda:0"), alibi=get_alibi(heads,context).to("cuda:0",dtype), backend="pytorch")
    flash2 = sdpa_wrapper(Q, K, V, causal=True, alibi=get_m(heads).to("cuda:0"), backend="flash2")
    torch.testing.assert_close(flash2, pytorch, check_dtype=False)
    # ALiBi not supported on FlashAttention3/4
    print("\x1b[32m ✔\x1b[0m")

    print("\x1b[1mcausal+swa\x1b[0m",end="")
    pytorch = sdpa_wrapper(Q, K, V, causal=get_causal(context).to("cuda:0"), swa=get_swa(context,window).to("cuda:0"), backend="pytorch")
    flash2 = sdpa_wrapper(Q, K, V, causal=True, swa=(window,window), backend="flash2")
    torch.testing.assert_close(flash2, pytorch, check_dtype=False)
    flash3 = sdpa_wrapper(Q, K, V, causal=True, swa=(window,window), backend="flash3")
    torch.testing.assert_close(flash3, pytorch, check_dtype=False)
    flash4 = sdpa_wrapper(Q, K, V, causal=True, swa=(window,window), backend="flash4")
    torch.testing.assert_close(flash4, pytorch, check_dtype=False)
    print("\x1b[32m ✔\x1b[0m")

    print("\x1b[1mGQA\x1b[0m",end="")
    Q = torch.rand((batches, heads, context, d_head)).to("cuda:0", dtype)
    K = torch.rand((batches, groups, context, d_head)).to("cuda:0", dtype)
    V = torch.rand((batches, groups, context, d_head)).to("cuda:0", dtype)
    pytorch = sdpa_wrapper(Q, K, V, backend="pytorch")
    flash2 = sdpa_wrapper(Q, K, V, backend="flash2")
    torch.testing.assert_close(flash2, pytorch, check_dtype=False)
    flash3 = sdpa_wrapper(Q, K, V, backend="flash3")
    torch.testing.assert_close(flash3, pytorch, check_dtype=False)
    flash4 = sdpa_wrapper(Q, K, V, backend="flash4")
    torch.testing.assert_close(flash4, pytorch, check_dtype=False)
    print("\x1b[32m ✔\x1b[0m")
290
+
291
class MHSA(torch.nn.Module):
    """Multi-head self-attention with GQA, optional QK-norm, and quantized linears.

    With qk_norm=True the softmax scale is a learned per-head parameter of
    shape (1, heads, 1, 1); otherwise it is a float chosen by scale_type.

    BUG FIX: the HF export stripped the `quartet2` import but left code
    referencing it behind a `pass` placeholder, so `quartet=True` (the
    default) crashed with NameError. `quartet2` is now imported lazily and
    missing installs raise a clear ImportError.
    """

    def __init__(self, heads, d_head, scale_type="1/sqrt(d)", ratio=1, qk_norm=True, quartet=True, fake_quartet=False):
        super().__init__()

        self.heads = heads
        self.d_head = d_head
        self.d = heads * d_head
        self.scale_type = scale_type
        self.ratio = ratio
        self.groups = heads//ratio
        self.d_KV = self.groups * d_head
        self.qk_norm = qk_norm
        if qk_norm:
            # (batches*)heads*context*d_head — learned per-head scale, init sqrt(d_head)
            scale = torch.full((1,heads,1,1), sqrt(d_head))
            self.scale = torch.nn.Parameter(scale)
        else:
            if scale_type=="1/sqrt(d)":
                self.scale = 1/sqrt(d_head)
            elif scale_type=="1/d":
                self.scale = 1/d_head
        self.quartet = quartet
        self.fake_quartet = fake_quartet

        # Packing QKV gives negligible speed gains, while not allowing GQA, hurting code clarity and having side effects with μP
        if quartet:
            try:
                import quartet2
            except ImportError as e:
                raise ImportError(
                    "quartet=True requires the 'quartet2' package, which is not "
                    "bundled with this HF export; pass quartet=False (or fake_quartet=True)."
                ) from e
            self.lq = quartet2.linear.Quartet_II_linear(self.d, self.d, bias=False)
            self.lk = quartet2.linear.Quartet_II_linear(self.d, self.d_KV, bias=False)
            self.lv = quartet2.linear.Quartet_II_linear(self.d, self.d_KV, bias=False)

            self.lo = quartet2.linear.Quartet_II_linear(self.d, self.d, bias=False)
        elif fake_quartet:
            from . import fake_quartet as fq
            self.lq = fq.FakeQuartetLinear(self.d, self.d, bias=False)
            self.lk = fq.FakeQuartetLinear(self.d, self.d_KV, bias=False)
            self.lv = fq.FakeQuartetLinear(self.d, self.d_KV, bias=False)

            self.lo = fq.FakeQuartetLinear(self.d, self.d, bias=False)
        else:
            self.lq = torch.nn.Linear(self.d, self.d, bias=False)
            self.lk = torch.nn.Linear(self.d, self.d_KV, bias=False)
            self.lv = torch.nn.Linear(self.d, self.d_KV, bias=False)

            self.lo = torch.nn.Linear(self.d, self.d, bias=False)

    # (batches*)context*d
    def forward(self, X, causal=None, rope=None, alibi=None, swa=None, return_A=False, backend="flash2"):
        # (batches*)context*d
        Q = self.lq(X)
        # (batches*)context*d_KV
        K = self.lk(X)
        V = self.lv(X)

        # (batches*)context*heads*d_head
        Q = Q.unflatten(dim=-1, sizes=(self.heads, self.d_head))
        # (batches*)context*groups*d_head
        K = K.unflatten(dim=-1, sizes=(self.groups, self.d_head))
        V = V.unflatten(dim=-1, sizes=(self.groups, self.d_head))

        # (batches*)heads*context*d_head
        Q = Q.movedim(-3,-2)
        # (batches*)groups*context*d_head
        K = K.movedim(-3,-2)
        V = V.movedim(-3,-2)

        if rope is not None:
            Q = apply_rope(Q,rope)
            K = apply_rope(K,rope)

        # QK-norm is applied after RoPE
        if self.qk_norm:
            Q = mlp.sphere_norm(Q)
            K = mlp.sphere_norm(K)

        # (batches*)heads*context*d_head
        if not return_A:
            Y = sdpa_wrapper(Q, K, V, causal, alibi, swa, self.scale, return_A, backend)
        else:
            Y, A__, A_, A = sdpa_wrapper(Q, K, V, causal, alibi, swa, self.scale, return_A, backend)
        # (batches*)context*heads*d_head
        Y = Y.movedim(-3,-2)
        # (batches*)context*d
        Y = Y.flatten(-2,-1)

        Y = self.lo(Y)

        if not return_A:
            return Y
        else:
            return Y, A__, A_, A
382
+
383
class Block(torch.nn.Module):
    """Pre/post-norm transformer block: MHSA + MLP, each with residual.

    The various *_norm flags select where normalization is applied (None
    when disabled, via mlp.get_norm). forward returns Z (block output),
    optionally the attention residual stream Y__ (return_res) and the
    attention matrices A__/A_/A (return_A, pytorch backend only).
    """
    def __init__(self, heads, d_head, scale_type="1/sqrt(d)", ratio=1, exp_factor=4, dropout=0, norm_type="rms_learned", bias=False, act=mlp.ReLU2(), l1_type="linear", pre_att_norm=False, qk_norm=True, out_att_norm=True, pre_mlp_norm=False, act_norm=False, out_mlp_norm=True, quartet=True, fake_quartet=False):
        super().__init__()

        self.heads = heads
        self.d_head = d_head
        self.d = heads * d_head
        self.scale_type = scale_type
        self.ratio = ratio
        self.groups = heads//ratio
        self.exp_factor = exp_factor
        # MLP hidden width = exp_factor * model dim.
        self.d_hidden = int(exp_factor*self.d)
        self.dropout = dropout
        self.norm_type = norm_type
        self.bias = bias
        self.act = act
        self.l1_type = l1_type

        self.mhsa = MHSA(heads, d_head, scale_type, ratio, qk_norm, quartet, fake_quartet)
        self.pre_att_norm = mlp.get_norm(pre_att_norm, norm_type, self.d, bias)
        self.out_att_norm = mlp.get_norm(out_att_norm, norm_type, self.d, bias)

        self.mlp = mlp.MLP2L(self.d, self.d_hidden, self.d, bias, act, dropout, l1_type, norm_type, act_norm, quartet, fake_quartet)
        self.pre_mlp_norm = mlp.get_norm(pre_mlp_norm, norm_type, self.d, bias)
        self.out_mlp_norm = mlp.get_norm(out_mlp_norm, norm_type, self.d, bias)

        self.quartet = quartet
        self.fake_quartet = fake_quartet

    def forward(self, X, causal=None, rope=None, alibi=None, swa=None, return_res=False, return_A=False, backend="flash2"):
        # Attention sub-layer (optionally pre-normed input).
        mhsa = self.mhsa(self.pre_att_norm(X) if self.pre_att_norm else X, causal, rope, alibi, swa, return_A, backend)
        if not return_A:
            Y = mhsa
        else:
            Y, A__, A_, A = mhsa

        if self.out_att_norm: Y = self.out_att_norm(Y)

        # Residual connection around attention (with dropout on the branch).
        Y_ = torch.nn.functional.dropout(Y, p=self.dropout, training=self.training)
        Y__ = X + Y_

        # MLP sub-layer (optionally pre-normed input).
        Z = self.mlp(self.pre_mlp_norm(Y__) if self.pre_mlp_norm else Y__)

        if self.out_mlp_norm: Z = self.out_mlp_norm(Z)

        # Residual connection around the MLP.
        Z_ = torch.nn.functional.dropout(Z, p=self.dropout, training=self.training)
        Z__ = Y__ + Z_

        if not return_res:
            if not return_A:
                return Z__
            else:
                return Z__, A__, A_, A
        else:
            if not return_A:
                return Z__, Y__
            else:
                return Z__, Y__, A__, A_, A
441
+
442
class Transformer(torch.nn.Module):
    """Decoder-style transformer LM: embedding -> blocks -> norm -> LM head.

    Supports several positional schemes (POS_TYPES), SDPA backends
    (BACKENDS), GQA via `ratio`, sliding-window attention via `window`,
    and optional weight tying between embedding and output head.
    NOTE(review): the "flex" backend branches reference causal_mod /
    alibi_mod / swa_mod, which are not defined in this file — presumably
    provided elsewhere; selecting backend="flex" as-is raises NameError.
    """
    def __init__(self, vocab_size=50304, num_blocks=12, heads=12, d_head=64, scale_type="1/sqrt(d)", ratio=1, is_causal=True, window=None, backend="flash2", exp_factor=4, dropout=0, pos_type="rope", max_context=128, norm_type="rms_learned", bias=False, act=mlp.ReLU2(), l1_type="linear", std=0.02, test=False, weight_tying=True, emb_norm=False, pre_att_norm=False, qk_norm=True, out_att_norm=True, pre_mlp_norm=False, act_norm=False, out_mlp_norm=True, out_norm=True, fix_norm=False, quartet=True, fake_quartet=False):
        super().__init__()

        self.vocab_size = vocab_size
        self.num_blocks = num_blocks
        self.heads = heads
        self.d_head = d_head
        self.d = heads * d_head
        self.scale_type = scale_type
        self.ratio = ratio
        self.groups = heads//ratio
        self.is_causal = is_causal
        self.window = window
        self.backend = backend
        self.exp_factor = exp_factor
        self.dropout = dropout
        self.pos_type = pos_type
        self.max_context = max_context
        self.norm_type = norm_type
        self.bias = bias
        self.act = act
        self.l1_type = l1_type
        self.weight_tying = weight_tying
        self.fix_norm = fix_norm
        self.quartet = quartet
        self.fake_quartet = fake_quartet

        self.emb = torch.nn.Embedding(vocab_size, self.d)

        self.emb_norm = mlp.get_norm(emb_norm, norm_type, self.d, bias)

        # Learned absolute positions (zero-initialized by init()).
        if pos_type == "learned":
            pos = torch.rand((max_context, self.d))
            self.pos = torch.nn.Parameter(pos)

        self.blocks = torch.nn.Sequential(*[Block(heads, d_head, scale_type, ratio, exp_factor, dropout, norm_type, bias, act, l1_type, pre_att_norm, qk_norm, out_att_norm, pre_mlp_norm, act_norm, out_mlp_norm, quartet, fake_quartet) for _ in range(num_blocks)])

        self.out_norm = mlp.get_norm(out_norm, norm_type, self.d, bias)

        self.linear = torch.nn.Linear(self.d, vocab_size, bias=False)

        # Share one weight matrix between embedding and LM head.
        if weight_tying: self.emb.weight = self.linear.weight

        self.init(std, test)

        if fake_quartet:
            # Keep norm/embedding modules in bf16 for the fake-quantized path.
            for m in self.modules():
                if isinstance(m, (torch.nn.LayerNorm, torch.nn.RMSNorm, torch.nn.Embedding)):
                    m.to(torch.bfloat16)

    def init(self, std=0.02, test=False):
        """Initialize parameters by parent-module type; print stats if `test`."""
        if test: print("\x1b[1m%36.36s %8.8s %8.8s %8.8s\x1b[0m" % ("parameter_name", "suffix", "mean", "std"))
        for parameter_name, parameter in self.named_parameters():
            parent_name, _, suffix = parameter_name.rpartition(".")
            parent = self.get_submodule(parent_name)

            if isinstance(parent, (torch.nn.Linear, torch.nn.Embedding)) and suffix=="weight":
                torch.nn.init.normal_(parameter, 0, std)
            elif isinstance(parent, (torch.nn.Linear, torch.nn.LayerNorm)) and suffix=="bias":
                torch.nn.init.zeros_(parameter)
            elif isinstance(parent, (torch.nn.LayerNorm, torch.nn.RMSNorm)) and suffix=="weight":
                torch.nn.init.ones_(parameter)
            else:
                # pos (learned positions, ndim 2)
                if parameter.ndim == 2:
                    torch.nn.init.zeros_(parameter)
                # scale (per-head QK-norm scale, ndim 4)
                elif parameter.ndim == 4:
                    torch.nn.init.constant_(parameter, sqrt(self.d_head))

            if test:
                print("%36.36s %8.8s %8.8s %8.8s\x1b[0m" % (parameter_name, suffix, "%f" % parameter.mean(), "%f" % parameter.std()))

    # (batches*)context
    def forward(self, ids, return_res=False, return_A=False):
        """Map token ids to logits; optionally collect residuals/attention."""
        context = ids.shape[-1]

        if return_A:
            # (batches*)num_blocks*heads*context*context
            A__ = torch.empty(*ids.shape[:-1], self.num_blocks, self.heads, context, context)
            A_ = torch.empty_like(A__)
            A = torch.empty_like(A__)

        # (batches*)context*d
        X = self.emb(ids)

        if return_res:
            res_in = X
            # (batches*)num_blocks*context*d
            res_att = torch.empty(*ids.shape[:-1], self.num_blocks, context, self.d)
            res_mlp = torch.empty(*ids.shape[:-1], self.num_blocks, context, self.d)

        # Recompute in every batch in case context changes
        if self.is_causal:
            if self.backend=="pytorch":
                causal = get_causal(context).to(ids.device)
            elif self.backend in {"flash2", "flash3", "flash4"}:
                causal = True
            elif self.backend=="flex":
                causal = causal_mod
            elif self.backend=="cudnn":
                # right_bound
                causal = 0
        else: causal = None

        if self.pos_type == "sinusoidal":
            pos = get_sinusoidal(context, self.d).to(ids.device)
            X = X + pos

        if self.pos_type == "learned":
            X = X + self.pos[:context,:]

        if self.pos_type == "rope":
            rope = get_rope(context, self.d_head, device=ids.device)
        else: rope = None

        if self.pos_type == "alibi":
            if self.backend=="pytorch":
                alibi = get_alibi(self.heads, context).to(ids.device)
            elif self.backend in {"flash2", "flash3", "flash4"}:
                alibi = get_m(self.heads).to(ids.device)
            elif self.backend=="flex":
                alibi = alibi_mod
            elif self.backend=="cudnn":
                alibi = True
        else: alibi = None

        if self.window is not None:
            if self.backend=="pytorch":
                swa = get_swa(context, self.window).to(ids.device)
            elif self.backend in {"flash2", "flash3", "flash4"}:
                swa = (self.window, self.window)
            elif self.backend=="flex":
                swa = swa_mod
            elif self.backend=="cudnn":
                # left_bound
                swa = self.window
        else: swa = None

        # Embedding norm is applied after positional encoding
        if self.emb_norm: X = self.emb_norm(X)

        X_ = torch.nn.functional.dropout(X, p=self.dropout, training=self.training)

        Y = X_
        for i, block in enumerate(self.blocks):
            if not return_res:
                if not return_A:
                    Y = block(Y, causal, rope, alibi, swa, return_res, return_A, self.backend)
                else:
                    Y, A__[...,i,:,:,:], A_[...,i,:,:,:], A[...,i,:,:,:] = block(Y, causal, rope, alibi, swa, return_res, return_A, self.backend)
            else:
                if not return_A:
                    Y, res_att[...,i,:,:] = block(Y, causal, rope, alibi, swa, return_res, return_A, self.backend)
                    res_mlp[...,i,:,:]= Y
                else:
                    Y, res_att[...,i,:,:], A__[...,i,:,:,:], A_[...,i,:,:,:], A[...,i,:,:,:] = block(Y, causal, rope, alibi, swa, return_res, return_A, self.backend)
                    res_mlp[...,i,:,:]= Y

        if self.out_norm: Y = self.out_norm(Y)

        # (batches*)context*vocab_size
        if self.fix_norm:
            # FixNorm: project LM-head rows onto the unit sphere before the matmul.
            Z = torch.nn.functional.linear(Y, mlp.sphere_norm(self.linear.weight))
        else:
            Z = self.linear(Y)

        if not return_res:
            if not return_A:
                return Z
            else:
                return Z, A__, A_, A
        else:
            if not return_A:
                return Z, res_in, res_att, res_mlp
            else:
                return Z, res_in, res_att, res_mlp, A__, A_, A
620
+
621
def get_attention_header(transformer):
    """Space-separated column header "block{b}.head{h}" for every block/head."""
    columns = [
        f"block{b}.head{h}"
        for b in range(transformer.num_blocks)
        for h in range(transformer.heads)
    ]
    return " ".join(columns)
632
+
633
def get_attention(W):
    """Format a (blocks, heads) matrix as space-separated "%.2f" values."""
    # rows->y, columns->x
    cells = [
        "%.2f" % W[b, h]
        for b in range(W.shape[0])
        for h in range(W.shape[1])
    ]
    return " ".join(cells)
645
+
646
def get_similarity_header(transformer):
    """Header "embedding block0 block1 ..." for similarity logging."""
    names = ["embedding"] + [f"block{b}" for b in range(transformer.num_blocks)]
    return " ".join(names)
656
+
657
def get_similarity(embeddings_x, embeddings_y):
    """Per-block cosine similarity between two (blocks, d) stacks, "%.2f"-joined."""
    values = [
        "%.2f" % torch.nn.functional.cosine_similarity(embeddings_x[b, :], embeddings_y[b, :], dim=0)
        for b in range(embeddings_x.shape[0])
    ]
    return " ".join(values)
667
+
668
def get_clustering_header(transformer):
    """Header of x/y columns per projection method, for embedding + each block."""
    methods = ("random", "pca", "mds", "tsne", "umap")
    prefixes = ["embedding"] + [f"block{b}" for b in range(transformer.num_blocks)]
    columns = [
        f"{prefix}.{method}.{axis}"
        for prefix in prefixes
        for method in methods
        for axis in ("x", "y")
    ]
    return " ".join(columns)
686
+
687
def get_clustering(random, pca, mds, tsne, umap):
    """Per-block "%f"-formatted (x, y) pairs for the five 2-D projections."""
    rows = []
    for b in range(random.shape[0]):
        coords = (random[b, 0], random[b, 1], pca[b, 0], pca[b, 1],
                  mds[b, 0], mds[b, 1], tsne[b, 0], tsne[b, 1],
                  umap[b, 0], umap[b, 1])
        rows.append("%f %f %f %f %f %f %f %f %f %f" % coords)
    return " ".join(rows)
fake_quartet.py ADDED
@@ -0,0 +1,348 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ from random import randint
3
+
4
+ import torch
5
+ import torch.nn.functional as F
6
+ import triton
7
+ import triton.language as tl
8
+ from scipy.linalg import hadamard
9
+
10
+
11
+
12
def get_hadamard_matrix(group_size: int, dtype: torch.dtype, device):
    """Build an orthonormal Hadamard matrix of size `group_size`.

    scipy's hadamard() returns ±1 entries; dividing by sqrt(group_size)
    makes the rows orthonormal, so H @ H.T == I.
    """
    scaled = hadamard(group_size) * group_size ** -0.5
    return torch.tensor(scaled, dtype=dtype, device=device, requires_grad=False)
19
+
20
+
21
def rerotate_hadamard(hadamard_matrix):
    """Randomly flip the sign of each column of the Hadamard matrix.

    Equivalent to hadamard_matrix @ diag(s) with s drawn uniformly from
    {-1, +1}; multiplying by ±1 is exact, so only signs change.
    """
    n = hadamard_matrix.size(0)
    flips = torch.randint(
        0, 2, (n,),
        device=hadamard_matrix.device,
        dtype=hadamard_matrix.dtype,
    ) * 2 - 1
    # Broadcasting over the last dim scales column j by flips[j],
    # which is exactly H @ diag(flips).
    return hadamard_matrix * flips
30
+
31
+
32
+
33
@triton.jit
def _rtn_fp4(x):
    # Round-to-nearest onto the FP4 (e2m1) magnitude grid
    # {0, 0.5, 1, 1.5, 2, 3, 4, 6}, keeping the sign separately.
    # Each threshold is the midpoint between two adjacent representable
    # magnitudes (e.g. 5 between 4 and 6), so the nested tl.where chain
    # selects the nearest grid value; ties at a midpoint round up.
    x_abs = tl.abs(x)
    x_sign = tl.where(x > 0, 1, -1)
    x_fp4_abs = tl.where(
        x_abs >= 5, 6,
        tl.where(x_abs >= 3.5, 4,
        tl.where(x_abs >= 2.5, 3,
        tl.where(x_abs >= 1.75, 2,
        tl.where(x_abs >= 1.25, 1.5,
        tl.where(x_abs >= 0.75, 1,
        tl.where(x_abs >= 0.25, 0.5,
        0.0)))))))
    return x_fp4_abs * x_sign
47
+
48
+
49
@triton.jit
def _get_scales(x, amax, val_max, scales_max):
    # Two-level scaling for grouped FP4:
    #   s_dec        — one global decode scale derived from the tensor abs-max;
    #                  falls back to 1.0 when amax == 0 (all-zero tensor).
    #   s_dec_b_e4m3 — per-group scale (group abs-max / val_max) expressed
    #                  relative to s_dec and itself quantized to FP8-E4M3;
    #                  zero group scales are replaced by 1.0 to avoid
    #                  division by zero downstream.
    s_dec = tl.where(amax == 0.0, 1.0, amax / scales_max / val_max)
    s_dec_b = tl.max(tl.abs(x), axis=-1, keep_dims=True) / val_max
    s_dec_b_e4m3 = (s_dec_b / s_dec).to(tl.float8e4nv).to(tl.float32)
    s_dec_b_e4m3 = tl.where(s_dec_b_e4m3 == 0, 1.0, s_dec_b_e4m3)
    return s_dec_b_e4m3, s_dec
56
+
57
+
58
@triton.jit
def _get_alt_scales(x, val_max, s_dec):
    # Alternate per-group scales boosted by 6/4: the group max then maps to
    # FP4 value 4 instead of 6. Used by the four_over_six candidate whose
    # reconstruction error is compared against the standard scaling.
    s_dec_b = tl.max(tl.abs(x), axis=-1, keep_dims=True) / val_max
    s_dec_b_e4m3 = (s_dec_b * (6 / 4) / s_dec).to(tl.float8e4nv).to(tl.float32)
    s_dec_b_e4m3 = tl.where(s_dec_b_e4m3 == 0, 1.0, s_dec_b_e4m3)
    return s_dec_b_e4m3
64
+
65
+
66
@triton.autotune(
    configs=[
        triton.Config({"BLOCK_SIZE": 64 * 32}),
        triton.Config({"BLOCK_SIZE": 128 * 32}),
        triton.Config({"BLOCK_SIZE": 256 * 32}),
        triton.Config({"BLOCK_SIZE": 512 * 32}),
    ],
    key=[],
)
@triton.jit
def _rtn_1x16s_fp4_kernel(
    x_ptr, amax_ptr, output_ptr,
    n_elements: tl.constexpr,
    scale_override: tl.constexpr,
    group_size: tl.constexpr,
    four_over_six: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    # Round-to-nearest grouped FP4 fake quantization with FP8-E4M3 group
    # scales. Each program handles one BLOCK_SIZE chunk; with four_over_six
    # an alternate scaling (targeting FP4 value 4 instead of 6) is tried and
    # the lower-error candidate is kept per group.
    pid = tl.program_id(0)
    start_idx = pid * BLOCK_SIZE
    offsets = start_idx + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x_flat = tl.load(x_ptr + offsets, mask=mask, other=0.0)

    # View the chunk as rows of `group_size` elements sharing one scale.
    x_grouped = tl.reshape(x_flat, (BLOCK_SIZE // group_size, group_size))

    # Scale budget for the FP8-encoded group scales differs per mode.
    scales_max = 256.00 if four_over_six else 448.00
    val_max = 6.0 / scale_override
    amax = tl.load(amax_ptr)

    s_dec_b_e4m3, s_dec = _get_scales(x_grouped, amax, val_max, scales_max)
    x_scaled = x_grouped / (s_dec_b_e4m3 * s_dec)

    # Quantize, then immediately dequantize (fake quantization).
    x_fp4 = _rtn_fp4(x_scaled)
    x_dequantized = x_fp4 * (s_dec_b_e4m3 * s_dec)

    if not four_over_six:
        best_x_dequantized = x_dequantized
    else:
        # Alternate candidate with 6/4-boosted scales; choose per group by
        # squared reconstruction error (ties keep the standard scaling).
        alt_s_dec_b_e4m3 = _get_alt_scales(x_grouped, val_max, s_dec)
        alt_x_scaled = x_grouped / (alt_s_dec_b_e4m3 * s_dec)
        alt_x_fp4 = _rtn_fp4(alt_x_scaled)
        alt_x_dequantized = alt_x_fp4 * (alt_s_dec_b_e4m3 * s_dec)

        error_six = tl.sum((x_grouped - x_dequantized) * (x_grouped - x_dequantized), axis=-1, keep_dims=True)
        error_four = tl.sum((x_grouped - alt_x_dequantized) * (x_grouped - alt_x_dequantized), axis=-1, keep_dims=True)
        best_x_dequantized = tl.where(error_six <= error_four, x_dequantized, alt_x_dequantized)

    x_dequantized_flat = tl.reshape(best_x_dequantized, (BLOCK_SIZE,))
    tl.store(output_ptr + offsets, x_dequantized_flat, mask=mask)
116
+
117
+
118
@torch.compiler.disable()
def rtn_1x16s_fp4(x, scale_override: float, group_size: int, four_over_six: bool):
    """Fake-quantize `x` to grouped FP4 via the RTN Triton kernel.

    Returns a dequantized tensor with the same shape and dtype as `x`.
    The tensor-wide abs-max is computed eagerly and handed to the kernel
    as the global scale anchor.
    """
    flat = x.contiguous()
    result = torch.empty_like(flat)
    total = flat.numel()
    # One program per BLOCK_SIZE chunk; BLOCK_SIZE is picked by autotune.
    grid = lambda meta: (triton.cdiv(total, meta["BLOCK_SIZE"]),)
    _rtn_1x16s_fp4_kernel[grid](
        x_ptr=flat,
        amax_ptr=flat.abs().max(),
        output_ptr=result,
        n_elements=total,
        scale_override=scale_override,
        group_size=group_size,
        four_over_six=four_over_six,
    )
    return result
130
+
131
+
132
+
133
@triton.autotune(
    configs=[
        triton.Config({"BLOCK_SIZE": 64 * 32}),
        triton.Config({"BLOCK_SIZE": 128 * 32}),
        triton.Config({"BLOCK_SIZE": 256 * 32}),
        triton.Config({"BLOCK_SIZE": 512 * 32}),
    ],
    key=[],
)
@triton.jit
def _eden_1x16s_fp4_kernel(
    x_ptr, hadamard_matrix_ptr, current_amax_ptr, output_ptr, next_amax_ptr,
    n_elements: tl.constexpr,
    hadamard_dim: tl.constexpr,
    scale_override: tl.constexpr,
    group_size: tl.constexpr,
    seed: int,
    BLOCK_SIZE: tl.constexpr,
):
    # EDEN-style fake quantization used on the backward path:
    #   1. rotate rows of length `hadamard_dim` by the Hadamard matrix,
    #   2. quantize to grouped FP4 with FP8-E4M3 per-group scales,
    #   3. apply a per-row least-squares correction to the scales,
    #   4. stochastically round the corrected scales to FP8-E4M3,
    #   5. store the dequantized result.
    # It also records the post-rotation abs-max into next_amax_ptr so the
    # caller can use it as the (delayed) amax of the next step.
    pid = tl.program_id(0)
    start_idx = pid * BLOCK_SIZE
    offsets = start_idx + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x_flat = tl.load(x_ptr + offsets, mask=mask, other=0.0)

    # Load the full rotation matrix and rotate each row of the chunk.
    offsets_hadamard = tl.arange(0, hadamard_dim * hadamard_dim)
    hadamard_matrix = tl.load(hadamard_matrix_ptr + offsets_hadamard).reshape(hadamard_dim, hadamard_dim)
    x = tl.reshape(x_flat, (BLOCK_SIZE // hadamard_dim, hadamard_dim))
    x_had = tl.dot(x, hadamard_matrix)

    # Cross-program max of the rotated values (relaxed ordering suffices:
    # only the final max matters, not intermediate visibility).
    tl.atomic_max(next_amax_ptr, tl.max(tl.abs(x_had)).to(tl.float32), sem="relaxed")

    x_grouped = tl.reshape(x_had, (BLOCK_SIZE // group_size, group_size))

    # Global decode scale from the (possibly delayed) tensor abs-max.
    scales_max = 255.99
    val_max = 6.0 / scale_override
    amax = tl.load(current_amax_ptr)
    s_dec = tl.where(amax == 0.0, 1.0, amax / scales_max / val_max)

    # Per-group scale relative to s_dec, itself quantized to FP8-E4M3
    # (same scheme as _get_scales; inlined here).
    s_dec_b = tl.max(tl.abs(x_grouped), axis=-1, keep_dims=True) / val_max
    s_dec_b_e4m3 = (s_dec_b / s_dec).to(tl.float8e4nv).to(tl.float32)
    s_dec_b_e4m3 = tl.where(s_dec_b_e4m3 == 0, 1.0, s_dec_b_e4m3)
    x_scaled = x_grouped / (s_dec_b_e4m3 * s_dec)

    # Round-to-nearest onto the FP4 (e2m1) grid (inlined _rtn_fp4).
    x_scaled_abs = tl.abs(x_scaled)
    x_scaled_sign = tl.where(x_scaled > 0, 1, -1)
    x_fp4 = tl.where(
        x_scaled_abs >= 5, 6,
        tl.where(x_scaled_abs >= 3.5, 4,
        tl.where(x_scaled_abs >= 2.5, 3,
        tl.where(x_scaled_abs >= 1.75, 2,
        tl.where(x_scaled_abs >= 1.25, 1.5,
        tl.where(x_scaled_abs >= 0.75, 1,
        tl.where(x_scaled_abs >= 0.25, 0.5,
        0))))))) * x_scaled_sign

    # Per-row least-squares gain: argmin_c ||x - c*q||^2 = <x,x> / <x,q>;
    # guard against all-zero rows where the denominator vanishes.
    x_scaled = tl.reshape(x_scaled, (BLOCK_SIZE // hadamard_dim, hadamard_dim))
    x_fp4 = tl.reshape(x_fp4, (BLOCK_SIZE // hadamard_dim, hadamard_dim))

    num = tl.sum(x_scaled * x_scaled, axis=-1, keep_dims=True)
    denom = tl.sum(x_scaled * x_fp4, axis=-1, keep_dims=True)
    correction = tl.where(denom == 0.0, 1.0, num / denom)

    # Broadcast the row correction onto each group scale of that row.
    scales = tl.reshape(s_dec_b_e4m3, (BLOCK_SIZE // hadamard_dim, hadamard_dim // group_size))
    corrected_scales = tl.reshape(scales * correction, (BLOCK_SIZE // group_size, 1))

    # Stochastic rounding of the corrected scales between the two adjacent
    # representable FP8-E4M3 values, found by bit-level neighbor probing.
    bitscales = tl.cast(corrected_scales.to(tl.float8e4nv), tl.uint8, bitcast=True)
    prevscale = tl.cast((bitscales - 1), tl.float8e4nv, bitcast=True).to(tl.float32)
    currscale = tl.cast((bitscales), tl.float8e4nv, bitcast=True).to(tl.float32)
    nextscale = tl.cast((bitscales + 1), tl.float8e4nv, bitcast=True).to(tl.float32)

    up = tl.where(currscale > corrected_scales, currscale, nextscale)
    down = tl.where(currscale > corrected_scales, prevscale, currscale)
    prob_up = (corrected_scales - down) / (up - down)

    # Deterministic per-scale random stream: seeded by the host, offset by
    # the global scale index so programs never share draws.
    scale_start_idx = pid * (BLOCK_SIZE // group_size)
    scale_offsets = scale_start_idx + tl.arange(0, BLOCK_SIZE // group_size)
    sampled_prob = tl.rand(seed, scale_offsets).reshape(BLOCK_SIZE // group_size, 1)

    scales = tl.where(sampled_prob < prob_up, up, down)
    scales = tl.reshape(scales, (BLOCK_SIZE // group_size, 1))
    x_fp4 = tl.reshape(x_fp4, (BLOCK_SIZE // group_size, group_size))

    # Dequantize and store in the input's element dtype.
    x_dequantized = x_fp4 * scales * s_dec
    x_dequantized_flat = tl.reshape(x_dequantized, (BLOCK_SIZE,))
    tl.store(output_ptr + offsets, x_dequantized_flat.to(x_ptr.dtype.element_ty), mask=mask)
219
+
220
+
221
@torch.compiler.disable()
def eden_1x16s_fp4(x, hadamard_matrix, scale_override: float, group_size: int, current_amax):
    """Launch the EDEN FP4 fake-quantization kernel on `x`.

    Returns (dequantized tensor, next_amax) where next_amax is the abs-max
    of the Hadamard-rotated input, recorded by the kernel for delayed-amax
    reuse on the following step.
    """
    dim = hadamard_matrix.size(0)
    x_contig = x.contiguous()
    # The kernel right-multiplies rows by the matrix; pre-transpose once.
    rotation = hadamard_matrix.T.contiguous()
    result = torch.empty_like(x_contig)
    next_amax = torch.zeros_like(current_amax)
    # Host-side seed for the kernel's stochastic scale rounding.
    rng_seed = randint(0, 1_000_000)
    total = x_contig.numel()
    grid = lambda meta: (triton.cdiv(total, meta["BLOCK_SIZE"]),)
    _eden_1x16s_fp4_kernel[grid](
        x_ptr=x_contig,
        hadamard_matrix_ptr=rotation,
        current_amax_ptr=current_amax,
        output_ptr=result,
        next_amax_ptr=next_amax,
        n_elements=total,
        hadamard_dim=dim,
        scale_override=scale_override,
        group_size=group_size,
        seed=rng_seed,
    )
    return result, next_amax
239
+
240
+
241
+
242
class AmaxStorage:
    """Mutable holder for per-layer abs-max statistics of the backward pass.

    All fields start as None and are populated lazily the first time the
    corresponding tensor is quantized.
    NOTE(review): the 'weght_tht_amax' spelling is load-bearing — it is
    referenced by name elsewhere — so it must not be "fixed" here alone.
    """
    __slots__ = ("e_ht_amax", "weght_tht_amax", "e_tht_amax", "input_tht_amax")

    def __init__(self):
        # No statistics recorded yet.
        for slot in self.__slots__:
            setattr(self, slot, None)
250
+
251
+
252
+
253
class FakeQuartetFn(torch.autograd.Function):
    """FP4 quantization-aware linear op (pseudo-quant Quartet-II path).

    Forward: RTN fake-quantize activations and weights to grouped FP4, then
    run F.linear on the dequantized values. Backward: both gradient GEMMs
    run on EDEN-quantized operands (random Hadamard rotation + stochastic
    scale rounding).
    """
    # Number of elements sharing one FP4 group scale.
    group_size = 16
    forward_scale_override = 1.0
    # Empirical gain applied to the backward quantization grid.
    backward_scale_override = (17 / 16) * 0.93
    # Class-wide shared Hadamard matrix; created lazily by FakeQuartetLinear
    # and re-randomized (column sign flips) on every backward call.
    hadamard_matrix = None

    # NOTE(review): torch.compile wraps the staticmethod object here —
    # confirm the targeted PyTorch version supports this decorator order.
    @torch.compile(dynamic=False)
    @staticmethod
    def forward(ctx, input, weight, amax_storage, delayed_amax, disable_forward_quant, disable_backward_quant, four_over_six):
        # input: (batch, seq, in_dim); weight: (out_dim, in_dim).
        ctx.batch = input.shape[0]
        ctx.seq = input.shape[1]
        ctx.in_dim = weight.shape[1]
        ctx.out_dim = weight.shape[0]
        ctx.delayed_amax = delayed_amax
        ctx.amax_storage = amax_storage
        ctx.disable_backward_quant = disable_backward_quant

        if disable_forward_quant:
            input_fq = input
            weight_fq = weight
        else:
            # Round-to-nearest grouped FP4 on both GEMM operands.
            input_fq = rtn_1x16s_fp4(input, FakeQuartetFn.forward_scale_override, FakeQuartetFn.group_size, four_over_six)
            weight_fq = rtn_1x16s_fp4(weight, FakeQuartetFn.forward_scale_override, FakeQuartetFn.group_size, four_over_six)

        # Save the quantized operands so backward differentiates the same
        # values the forward actually multiplied.
        ctx.save_for_backward(input_fq, weight_fq)
        return F.linear(input_fq, weight_fq)

    @staticmethod
    def backward(ctx, grad_output):
        input_fq, weight_fq = ctx.saved_tensors
        dtype = grad_output.dtype
        # Collapse (batch, seq) into rows for the two gradient GEMMs.
        input_fq = input_fq.to(dtype).reshape(ctx.batch * ctx.seq, ctx.in_dim)
        weight_fq = weight_fq.to(dtype)
        grad_output = grad_output.reshape(ctx.batch * ctx.seq, ctx.out_dim)

        # Fresh random sign flip of the shared rotation matrix each step.
        FakeQuartetFn.hadamard_matrix = rerotate_hadamard(FakeQuartetFn.hadamard_matrix)

        if ctx.disable_backward_quant:
            # Exact (unquantized) gradients w.r.t. input and weight.
            grad_input = F.linear(grad_output, weight_fq.T, None).view(ctx.batch, ctx.seq, ctx.in_dim)
            grad_weight = F.linear(grad_output.T, input_fq.T, None)
            return grad_input, grad_weight, None, None, None, None, None

        had = FakeQuartetFn.hadamard_matrix.to(grad_output.dtype)
        bso = FakeQuartetFn.backward_scale_override
        gs = FakeQuartetFn.group_size

        # EW: grad_output @ weight^T
        # In delayed mode, reuse the amax the kernel recorded last step;
        # otherwise compute it from the freshly rotated tensor now.
        if ctx.amax_storage.e_ht_amax is None or not ctx.delayed_amax:
            ctx.amax_storage.e_ht_amax = (grad_output.reshape(-1, had.size(0)) @ had.T).abs().max().float()
        e_ht_fp4, ctx.amax_storage.e_ht_amax = eden_1x16s_fp4(grad_output, had, bso, gs, ctx.amax_storage.e_ht_amax)

        if ctx.amax_storage.weght_tht_amax is None or not ctx.delayed_amax:
            ctx.amax_storage.weght_tht_amax = (weight_fq.T.reshape(-1, had.size(0)) @ had.T).abs().max().float()
        weight_tht_fp4, ctx.amax_storage.weght_tht_amax = eden_1x16s_fp4(weight_fq.T, had, bso, gs, ctx.amax_storage.weght_tht_amax)

        # Both operands carry the same rotation, which cancels in the
        # product since the Hadamard matrix is orthogonal.
        grad_input = F.linear(e_ht_fp4, weight_tht_fp4, None).view(ctx.batch, ctx.seq, ctx.in_dim)

        # EtX: grad_output^T @ input
        if ctx.amax_storage.e_tht_amax is None or not ctx.delayed_amax:
            ctx.amax_storage.e_tht_amax = (grad_output.T.reshape(-1, had.size(0)) @ had.T).abs().max().float()
        e_tht_fp4, ctx.amax_storage.e_tht_amax = eden_1x16s_fp4(grad_output.T, had, bso, gs, ctx.amax_storage.e_tht_amax)

        if ctx.amax_storage.input_tht_amax is None or not ctx.delayed_amax:
            ctx.amax_storage.input_tht_amax = (input_fq.T.reshape(-1, had.size(0)) @ had.T).abs().max().float()
        input_tht_fp4, ctx.amax_storage.input_tht_amax = eden_1x16s_fp4(input_fq.T, had, bso, gs, ctx.amax_storage.input_tht_amax)

        grad_weight = F.linear(e_tht_fp4, input_tht_fp4, None)

        # One gradient slot per forward argument: tensors for (input, weight),
        # None for the five flag/storage arguments.
        return grad_input, grad_weight, None, None, None, None, None
322
+
323
+
324
+
325
class FakeQuartetLinear(torch.nn.Linear):
    """Drop-in nn.Linear whose matmuls route through FakeQuartetFn (FP4 QAT)."""

    def __init__(self, *args, hadamard_dim=32, delayed_amax=False,
                 disable_forward_quant=False, disable_backward_quant=False,
                 four_over_six=True, **kwargs):
        super().__init__(*args, **kwargs)
        self.hadamard_dim = hadamard_dim
        self.delayed_amax = delayed_amax
        self.disable_forward_quant = disable_forward_quant
        self.disable_backward_quant = disable_backward_quant
        self.four_over_six = four_over_six
        # Per-layer abs-max statistics consumed by the backward pass.
        self.amax_storage = AmaxStorage()

        # Lazily create the class-wide shared Hadamard matrix on first use.
        # NOTE(review): only the first-constructed layer's hadamard_dim takes
        # effect, and the device is hard-coded to "cuda" — confirm intended.
        if FakeQuartetFn.hadamard_matrix is None:
            FakeQuartetFn.hadamard_matrix = get_hadamard_matrix(
                self.hadamard_dim, dtype=torch.float32, device="cuda",
            )

    def forward(self, x):
        # Argument order must match FakeQuartetFn.forward (after ctx).
        return FakeQuartetFn.apply(
            x, self.weight, self.amax_storage,
            self.delayed_amax, self.disable_forward_quant,
            self.disable_backward_quant, self.four_over_six,
        )
model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5802c11b6b024033386dba4cdff8665d48de19850e0e63c31686f44430ca870f
3
+ size 16563661264
modeling_cloverlm.py ADDED
@@ -0,0 +1,237 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ from math import sqrt
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ from transformers import PreTrainedModel, GenerationMixin
8
+ from transformers.modeling_outputs import CausalLMOutputWithPast
9
+
10
+ from .configuration_cloverlm import CloverLMConfig
11
+ from .fake_quartet import FakeQuartetLinear
12
+
13
+
14
+
15
def _sphere_norm(X, dim=-1):
    """Project X onto the unit sphere along `dim` (L2 normalization).

    Matches F.normalize's default behavior: the norm is clamped to 1e-12,
    so all-zero vectors map to zero instead of NaN.
    """
    denom = X.norm(2, dim=dim, keepdim=True).clamp_min(1e-12)
    return X / denom
17
+
18
+
19
class _ReLU2(nn.Module):
    """Squared-ReLU activation: max(x, 0)^2."""

    def forward(self, x):
        clipped = F.relu(x)
        return clipped * clipped
22
+
23
+
24
def _make_linear(in_f, out_f, bias, quartet_2_impl):
    """Factory for the quantized linear layers used throughout the model.

    "pseudoquant" builds the local fake-quantization layer; "quartet2" uses
    the real Quartet-II kernel package (imported lazily so the dependency is
    only required when selected). Any other value raises ValueError.
    """
    if quartet_2_impl == "pseudoquant":
        return FakeQuartetLinear(in_f, out_f, bias)
    if quartet_2_impl == "quartet2":
        try:
            from quartet2.linear import Quartet_II_linear
        except ImportError as err:
            err.add_note("Quartet_II_linear import failed. Install the latest quartet2 from https://github.com/IST-DASLab/Quartet-II")
            raise err
        return Quartet_II_linear(in_f, out_f, bias)
    raise ValueError(f"Unsupported quartet_2_impl: {quartet_2_impl}")
37
+
38
+
39
def _build_rope(context, d_head, device):
    """Precompute RoPE cos/sin tables for `context` positions.

    Frequencies use base 1024; each of the d_head//2 frequencies is repeated
    twice along the feature dimension so it applies to an (even, odd) pair.
    Returns a tensor of shape (2, context, d_head): [0] = cos, [1] = sin.
    """
    positions = torch.arange(context, device=device, dtype=torch.float32)
    freq_idx = torch.arange(d_head // 2, device=device, dtype=torch.float32)
    inv_freq = 1.0 / (1024.0 ** (2.0 * freq_idx / d_head))
    angles = torch.outer(positions, inv_freq)
    cos = torch.cos(angles).repeat_interleave(2, dim=1)
    sin = torch.sin(angles).repeat_interleave(2, dim=1)
    return torch.stack((cos, sin))
47
+
48
+
49
def _apply_rope(X, rope):
    """Apply rotary position embedding to the last dimension of X.

    Each (even, odd) feature pair is rotated by the per-position angle:
    (x_e, x_o) -> (x_e*cos - x_o*sin, x_o*cos + x_e*sin), with cos/sin
    taken from rope[0]/rope[1]. The result is cast back to X's dtype.
    """
    rotated = torch.empty_like(X)
    rotated[..., 0::2] = -X[..., 1::2]
    rotated[..., 1::2] = X[..., 0::2]
    cos, sin = rope[0], rope[1]
    return (X * cos + rotated * sin).to(X.dtype)
54
+
55
+
56
+
57
class _MLP(nn.Module):
    """Two-layer feed-forward block: linear -> ReLU^2 -> linear, no biases."""

    def __init__(self, d, d_hidden, quartet_2_impl):
        super().__init__()
        # Up-projection fused with the squared-ReLU activation.
        self.l1 = nn.Sequential(_make_linear(d, d_hidden, False, quartet_2_impl), _ReLU2())
        # Down-projection back to the model width.
        self.l2 = _make_linear(d_hidden, d, False, quartet_2_impl)

    def forward(self, x):
        hidden = self.l1(x)
        return self.l2(hidden)
66
+
67
+
68
+
69
class MHSA(nn.Module):
    """Multi-head self-attention with grouped KV heads (GQA), RoPE, and
    unit-sphere QK normalization with a learned per-head scale."""

    def __init__(self, heads, d_head, ratio, quartet_2_impl):
        super().__init__()
        self.heads = heads
        self.d_head = d_head
        self.d = heads * d_head
        # GQA: `ratio` query heads share one KV head.
        self.groups = heads // ratio
        d_kv = self.groups * d_head

        self.lq = _make_linear(self.d, self.d, False, quartet_2_impl)
        self.lk = _make_linear(self.d, d_kv, False, quartet_2_impl)
        self.lv = _make_linear(self.d, d_kv, False, quartet_2_impl)
        self.lo = _make_linear(self.d, self.d, False, quartet_2_impl)

        # Learned per-head temperature, initialized to sqrt(d_head) so that
        # scale * unit-norm(Q) recovers the usual dot-product magnitude.
        self.scale = nn.Parameter(torch.full((1, heads, 1, 1), sqrt(d_head)))

    def forward(self, X, rope, attn_backend):
        # NOTE(review): B is computed but never used below.
        B = X.shape[0] if X.dim() == 3 else 1
        ctx = X.shape[-2]

        # (..., ctx, n*d_head) -> (..., n, ctx, d_head) for n heads/groups.
        Q = self.lq(X).unflatten(-1, (self.heads, self.d_head)).movedim(-3, -2)
        K = self.lk(X).unflatten(-1, (self.groups, self.d_head)).movedim(-3, -2)
        V = self.lv(X).unflatten(-1, (self.groups, self.d_head)).movedim(-3, -2)

        # Rotate, then normalize Q and K to unit length (cosine attention).
        Q = _apply_rope(Q, rope)
        K = _apply_rope(K, rope)
        Q = _sphere_norm(Q)
        K = _sphere_norm(K)

        # Apply the learned temperature; the reshape restores Q's exact
        # shape after broadcasting against the (1, heads, 1, 1) parameter.
        Q_shape = Q.shape
        Q = self.scale * Q
        Q = Q.reshape(Q_shape)

        if attn_backend == "pytorch":
            # SDPA has no native GQA here: expand KV to the full head count.
            K = K.repeat_interleave(self.heads // self.groups, dim=-3)
            V = V.repeat_interleave(self.heads // self.groups, dim=-3)
            # scale=1.0 because Q already carries the temperature.
            Y = F.scaled_dot_product_attention(Q, K, V, is_causal=True, scale=1.0)
            Y = Y.movedim(-3, -2).flatten(-2, -1)
        elif attn_backend in ("flash2", "flash3", "flash4"):
            # Flash layouts are (batch, seq, heads, d_head); K/V may keep the
            # smaller `groups` head count (GQA handled by the kernel).
            Q = Q.movedim(-3, -2).reshape(-1, ctx, self.heads, self.d_head)
            K = K.movedim(-3, -2).reshape(-1, ctx, self.groups, self.d_head)
            V = V.movedim(-3, -2).reshape(-1, ctx, self.groups, self.d_head)

            # Flash kernels need fp16/bf16; fall back to bf16 otherwise.
            dtype = Q.dtype if Q.dtype in (torch.bfloat16, torch.float16) else torch.bfloat16
            if attn_backend == "flash2":
                import flash_attn
                Y = flash_attn.flash_attn_func(Q.to(dtype), K.to(dtype), V.to(dtype), causal=True, softmax_scale=1.0)
            elif attn_backend == "flash3":
                import importlib
                _fa3 = importlib.import_module("flash_attn_interface")
                Y = _fa3.flash_attn_func(Q.to(dtype), K.to(dtype), V.to(dtype), causal=True, softmax_scale=1.0)
            elif attn_backend == "flash4":
                import importlib
                _fa4 = importlib.import_module("flash_attn.cute")
                # The cute API returns a tuple; index 0 is the attention output.
                Y = _fa4.flash_attn_func(Q.to(dtype), K.to(dtype), V.to(dtype), causal=True, softmax_scale=1.0)[0]
            Y = Y.to(Q.dtype).flatten(-2, -1)

        return self.lo(Y)
127
+
128
+
129
+
130
class _Block(nn.Module):
    """Transformer block with post-norm residuals:
    x <- x + RMSNorm(attn(x)); x <- x + RMSNorm(mlp(x))."""

    def __init__(self, heads, d_head, ratio, quartet_2_impl):
        super().__init__()
        width = heads * d_head

        self.mhsa = MHSA(heads, d_head, ratio, quartet_2_impl)
        self.out_att_norm = nn.RMSNorm(width, elementwise_affine=True)

        self.mlp = _MLP(width, 4 * width, quartet_2_impl)
        self.out_mlp_norm = nn.RMSNorm(width, elementwise_affine=True)

    def forward(self, X, rope, attn_backend):
        # Attention sub-layer with normalized residual branch.
        after_attn = X + self.out_att_norm(self.mhsa(X, rope, attn_backend))
        # Feed-forward sub-layer, same residual pattern.
        return after_attn + self.out_mlp_norm(self.mlp(after_attn))
147
+
148
+
149
+
150
class _Transformer(nn.Module):
    """Decoder-only backbone: embedding -> blocks -> RMSNorm -> LM head."""

    def __init__(self, vocab_size, num_blocks, heads, d_head, ratio,
                 max_context, std, quartet_2_impl, weight_tying, attn_backend):
        # NOTE(review): max_context is accepted but not stored or enforced.
        super().__init__()
        self.d_head = d_head
        self.attn_backend = attn_backend
        d = heads * d_head

        self.emb = nn.Embedding(vocab_size, d)
        self.blocks = nn.Sequential(*[
            _Block(heads, d_head, ratio, quartet_2_impl) for _ in range(num_blocks)
        ])
        self.out_norm = nn.RMSNorm(d, elementwise_affine=True)
        self.linear = nn.Linear(d, vocab_size, bias=False)

        # Tie the input embedding and the output head to the same tensor.
        if weight_tying:
            self.emb.weight = self.linear.weight

        # Re-initialize after tying: linear/embedding weights ~ N(0, std),
        # RMSNorm gains to 1, and 4-D parameters (the MHSA per-head scale)
        # to sqrt(d_head).
        for name, p in self.named_parameters():
            parent_name, _, suffix = name.rpartition(".")
            parent = self.get_submodule(parent_name)
            if isinstance(parent, (nn.Linear, nn.Embedding)) and suffix == "weight":
                nn.init.normal_(p, 0, std)
            elif isinstance(parent, nn.RMSNorm) and suffix == "weight":
                nn.init.ones_(p)
            elif p.ndim == 4:
                nn.init.constant_(p, sqrt(d_head))

        # Cast the non-quantized float modules to bf16.
        # NOTE(review): truthiness check — any non-empty impl string (not
        # just "pseudoquant") triggers this; confirm intended.
        if quartet_2_impl:
            for m in self.modules():
                if isinstance(m, (nn.LayerNorm, nn.RMSNorm, nn.Embedding)):
                    m.to(torch.bfloat16)

    def forward(self, ids):
        # RoPE tables are rebuilt per call for the current sequence length.
        ctx = ids.shape[-1]
        rope = _build_rope(ctx, self.d_head, device=ids.device)

        X = self.emb(ids)
        for block in self.blocks:
            X = block(X, rope, self.attn_backend)
        X = self.out_norm(X)
        return self.linear(X)
193
+
194
+
195
+
196
class CloverLMForCausalLM(PreTrainedModel, GenerationMixin):
    """Hugging Face wrapper around the _Transformer backbone (no KV cache)."""
    config_class = CloverLMConfig
    supports_gradient_checkpointing = False
    _no_split_modules = ["_Block"]
    # The LM head shares its weight with the embedding (tied below).
    _tied_weights_keys = ["transformer.linear.weight"]
    _tp_plan = {}

    def __init__(self, config: CloverLMConfig):
        super().__init__(config)
        # Map each tied key to its source parameter; presumably required by
        # newer transformers tied-weight handling — confirm against the
        # pinned transformers version.
        self.all_tied_weights_keys = {k: "transformer.emb.weight"
                                      for k in (self._tied_weights_keys or [])}
        self.transformer = _Transformer(
            vocab_size=config.vocab_size,
            num_blocks=config.num_blocks,
            heads=config.heads,
            d_head=config.d_head,
            ratio=config.ratio,
            max_context=config.max_context,
            std=0.02,
            quartet_2_impl=config.quartet_2_impl,
            weight_tying=config.weight_tying,
            attn_backend=config.attn_backend,
        )

    def forward(self, input_ids, attention_mask=None, labels=None, **kwargs):
        # attention_mask is accepted for API compatibility but not used:
        # attention inside the backbone is always dense causal.
        logits = self.transformer(input_ids)

        loss = None
        if labels is not None:
            # Standard causal-LM shift: position t predicts token t+1.
            # F.cross_entropy's default ignore_index (-100) applies.
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss = F.cross_entropy(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1),
            )
        return CausalLMOutputWithPast(loss=loss, logits=logits)

    def prepare_inputs_for_generation(self, input_ids, **kwargs):
        # No KV cache: always re-feed the full sequence each decode step.
        return {"input_ids": input_ids}

    def _supports_default_dynamic_cache(self):
        # Opt out of HF's DynamicCache machinery; the model is cache-free.
        return False
tokenization_cloverlm.py ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ from typing import List, Optional
3
+ import tokenmonster
4
+ from transformers import PreTrainedTokenizer
5
+
6
+
7
# Pinned TokenMonster vocabulary (32k, English+code, strict, no capcode, v1)
# hosted on the HF Hub; "%3D" is the URL-encoded "=" in "eot=14199".
TOKENMONSTER_URL = (
    "https://huggingface.co/gvlassis/tokenmonster/resolve/main/"
    "englishcode-32000-strict-nocapcode-v1-eot%3D14199.vocab"
    "?download=true"
)
12
+
13
+
14
class CloverLMTokenizer(PreTrainedTokenizer):
    """TokenMonster-backed tokenizer exposed through the slow-tokenizer API.

    "Tokens" are the stringified integer ids produced by tokenmonster
    (e.g. "123"), so token<->id conversion is trivial and all real text
    handling happens in tokenmonster's tokenize/decode.
    """
    model_input_names = ["input_ids", "attention_mask"]

    def __init__(self, vocab_url: str = TOKENMONSTER_URL,
                 eot_id: int = 14199, **kwargs):
        # Load the vocabulary before super().__init__, which may invoke
        # tokenizer methods (get_vocab/_tokenize) during setup.
        self._tm = tokenmonster.load(vocab_url)
        self._eot_id = eot_id
        self._vocab_size = 32000

        # A single <eot> token serves as EOS, PAD, and BOS.
        super().__init__(
            eos_token="<eot>",
            pad_token="<eot>",
            bos_token="<eot>",
            **kwargs,
        )
        # Pin the special-token ids explicitly to the tokenmonster id.
        self.eos_token_id = eot_id
        self.pad_token_id = eot_id
        self.bos_token_id = eot_id

    @property
    def vocab_size(self) -> int:
        # Fixed size of the pinned TokenMonster vocabulary.
        return self._vocab_size

    def get_vocab(self):
        # Placeholder surface forms; the real strings live inside tokenmonster.
        return {f"<tok_{i}>": i for i in range(self._vocab_size)}

    def _tokenize(self, text: str, **kwargs) -> List[str]:
        # Tokens are the stringified tokenmonster ids.
        ids = self._tm.tokenize(text).tolist()
        return [str(i) for i in ids]

    def _convert_token_to_id(self, token: str) -> int:
        return int(token)

    def _convert_id_to_token(self, index: int) -> str:
        return str(index)

    def convert_tokens_to_string(self, tokens: List[str]) -> str:
        # Decode through tokenmonster to recover the original text.
        ids = [int(t) for t in tokens]
        return self._tm.decode(ids)

    @property
    def all_special_tokens_extended(self):
        return [self.eos_token]

    @property
    def all_special_tokens(self):
        return [self.eos_token]

    @property
    def all_special_ids(self):
        return [self._eot_id]

    def save_vocabulary(self, save_directory: str,
                        filename_prefix: Optional[str] = None):
        # Nothing to write: the vocabulary is always fetched from vocab_url.
        return ()
tokenizer_config.json ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "tokenizer_class": "CloverLMTokenizer",
3
+ "auto_map": {
4
+ "AutoTokenizer": [
5
+ "tokenization_cloverlm.CloverLMTokenizer",
6
+ null
7
+ ]
8
+ },
9
+ "use_fast": false
10
+ }