asd52403 committed on
Commit
da67f74
·
1 Parent(s): 5898447

init commit

Browse files
inference/configs/config_16B.json ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "vocab_size": 102400,
3
+ "dim": 2048,
4
+ "inter_dim": 10944,
5
+ "moe_inter_dim": 1408,
6
+ "n_layers": 27,
7
+ "n_dense_layers": 1,
8
+ "n_heads": 16,
9
+ "n_routed_experts": 64,
10
+ "n_shared_experts": 2,
11
+ "n_activated_experts": 6,
12
+ "route_scale": 1.0,
13
+ "q_lora_rank": 0,
14
+ "kv_lora_rank": 512,
15
+ "qk_nope_head_dim": 128,
16
+ "qk_rope_head_dim": 64,
17
+ "v_head_dim": 128,
18
+ "mscale": 0.707
19
+ }
inference/configs/config_236B.json ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "vocab_size": 102400,
3
+ "dim": 5120,
4
+ "inter_dim": 12288,
5
+ "moe_inter_dim": 1536,
6
+ "n_layers": 60,
7
+ "n_dense_layers": 1,
8
+ "n_heads": 128,
9
+ "n_routed_experts": 160,
10
+ "n_shared_experts": 2,
11
+ "n_activated_experts": 6,
12
+ "n_expert_groups": 8,
13
+ "n_limited_groups": 3,
14
+ "route_scale": 16.0,
15
+ "q_lora_rank": 1536,
16
+ "kv_lora_rank": 512,
17
+ "qk_nope_head_dim": 128,
18
+ "qk_rope_head_dim": 64,
19
+ "v_head_dim": 128
20
+ }
inference/configs/config_671B.json ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "vocab_size": 129280,
3
+ "dim": 7168,
4
+ "inter_dim": 18432,
5
+ "moe_inter_dim": 2048,
6
+ "n_layers": 61,
7
+ "n_dense_layers": 3,
8
+ "n_heads": 128,
9
+ "n_routed_experts": 256,
10
+ "n_shared_experts": 1,
11
+ "n_activated_experts": 8,
12
+ "n_expert_groups": 8,
13
+ "n_limited_groups": 4,
14
+ "route_scale": 2.5,
15
+ "score_func": "sigmoid",
16
+ "q_lora_rank": 1536,
17
+ "kv_lora_rank": 512,
18
+ "qk_nope_head_dim": 128,
19
+ "qk_rope_head_dim": 64,
20
+ "v_head_dim": 128,
21
+ "dtype": "fp8"
22
+ }
inference/configs/config_671B_test.json ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "vocab_size": 8080,
3
+ "dim": 7168,
4
+ "inter_dim": 1152,
5
+ "moe_inter_dim": 2048,
6
+ "n_layers": 61,
7
+ "n_dense_layers": 3,
8
+ "n_heads": 8,
9
+ "n_routed_experts": 16,
10
+ "n_shared_experts": 1,
11
+ "n_activated_experts": 8,
12
+ "n_expert_groups": 8,
13
+ "n_limited_groups": 4,
14
+ "route_scale": 2.5,
15
+ "score_func": "sigmoid",
16
+ "q_lora_rank": 1536,
17
+ "kv_lora_rank": 512,
18
+ "qk_nope_head_dim": 128,
19
+ "qk_rope_head_dim": 64,
20
+ "v_head_dim": 128,
21
+ "dtype": "fp8"
22
+ }
23
+
inference/convert2.py ADDED
@@ -0,0 +1,630 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import shutil
3
+ from argparse import ArgumentParser
4
+ from glob import glob
5
+ from tqdm import tqdm, trange
6
+
7
+ import torch
8
+ import ctypes
9
+ from safetensors.torch import safe_open, save_file
10
+ from kernel import weight_dequant
11
+
12
+
13
+ mapping = {
14
+ "embed_tokens": ("embed", 0),
15
+ "input_layernorm": ("attn_norm", None),
16
+ "post_attention_layernorm": ("ffn_norm", None),
17
+ "q_proj": ("wq", 0),
18
+ "q_a_proj": ("wq_a", None),
19
+ "q_a_layernorm": ("q_norm", None),
20
+ "q_b_proj": ("wq_b", 0),
21
+ "kv_a_proj_with_mqa": ("wkv_a", None),
22
+ "kv_a_layernorm": ("kv_norm", None),
23
+ "kv_b_proj": ("wkv_b", 0),
24
+ "o_proj": ("wo", 1),
25
+ "gate": ("gate", None),
26
+ "gate_proj": ("w1", 0),
27
+ "down_proj": ("w2", 1),
28
+ "up_proj": ("w3", 0),
29
+ "norm": ("norm", None),
30
+ "lm_head": ("head", 0),
31
+ "scale": ("scale", None),
32
+ }
33
+
34
+ EmbedsInOneFile = 256
35
+ EmbedsZKDir = "zkdata/embeds/"
36
+
37
+ wkv_b_1_rescales = [32, 34, 37, 36, 33, 32, 33, 33, 30, 32,
38
+ 32, 30, 31, 30, 29, 30, 29, 30, 29, 29,
39
+ 29, 29, 29, 29, 29, 29, 29, 29, 29, 29,
40
+ 29, 29, 29, 29, 29, 29, 29, 29, 29, 29,
41
+ 29, 29, 29, 29, 29, 29, 29, 29, 30, 30,
42
+ 29, 29, 30, 30, 30, 30, 29, 30, 30, 29, 30]
43
+
44
+ wkv_b_2_rescales = [31, 32, 32, 31, 32, 30, 30, 30, 30, 30,
45
+ 30, 30, 30, 29, 29, 29, 29, 30, 29, 29,
46
+ 29, 29, 29, 29, 30, 30, 30, 29, 29, 29,
47
+ 29, 29, 30, 29, 30, 29, 30, 29, 29, 29,
48
+ 30, 29, 29, 29, 29, 30, 29, 30, 30, 30,
49
+ 29, 29, 29, 30, 30, 29, 29, 29, 30, 30, 30]
50
+
51
+ wo_rescales = [31, 32, 32, 32, 32, 31, 32, 31, 31, 31,
52
+ 31, 31, 31, 31, 30, 31, 31, 32, 31, 31,
53
+ 31, 30, 30, 30, 30, 30, 30, 30, 30, 30,
54
+ 30, 30, 30, 30, 30, 30, 30, 30, 30, 30,
55
+ 30, 30, 30, 31, 30, 31, 30, 30, 31, 31,
56
+ 31, 30, 31, 31, 31, 30, 31, 31, 31, 31, 32 ]
57
+
58
+ gate_rescales = [0, 0, 0, 33, 32, 32, 32, 31, 32, 31, 30,
59
+ 32, 32, 32, 32, 32, 32, 32, 32, 32, 32,
60
+ 32, 31, 32, 31, 32, 32, 32, 32, 31, 32,
61
+ 32, 31, 32, 32, 32, 32, 32, 32, 32, 32,
62
+ 32, 32, 32, 32, 32, 32, 32, 32, 32, 32,
63
+ 32, 32, 32, 33, 33, 33, 33, 33, 32, 32 ]
64
+
65
+ w1_rescales = [32, 32, 32]
66
+ w2_rescales = [31, 32, 31]
67
+ w3_rescales = [32, 33, 32]
68
+
69
+ shared_w1_rescales = [0, 0, 0, 30, 30, 29, 29, 29, 28, 29,
70
+ 29, 28, 29, 29, 29, 29, 29, 29, 29, 29,
71
+ 29, 29, 29, 30, 30, 30, 30, 30, 30, 30,
72
+ 30, 30, 30, 30, 29, 29, 30, 29, 29, 30,
73
+ 29, 29, 29, 29, 29, 29, 29, 29, 29, 29,
74
+ 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29]
75
+
76
+ shared_w2_rescales = [0, 0, 0, 30, 30, 30, 30, 30, 29, 29,
77
+ 30, 29, 29, 29, 30, 30, 30, 30, 30, 29,
78
+ 29, 29, 29, 29, 29, 29, 29, 30, 30, 29,
79
+ 29, 29, 29, 29, 29, 29, 29, 30, 29, 29,
80
+ 29, 29, 29, 29, 29, 29, 29, 29, 29, 29,
81
+ 29, 29, 29, 29, 29, 29, 30, 30, 30, 30, 30]
82
+
83
+ shared_w3_rescales = [0, 0, 0, 30, 30, 30, 30, 30, 29, 29,
84
+ 30, 29, 29, 29, 30, 30, 30, 29, 30, 29,
85
+ 29, 29, 29, 29, 29, 29, 30, 30, 30, 30,
86
+ 29, 29, 29, 29, 29, 29, 29, 30, 30, 29,
87
+ 30, 29, 29, 29, 29, 30, 29, 29, 30, 30,
88
+ 29, 30, 30, 30, 29, 29, 30, 30, 30, 29, 28]
89
+
90
+ layer_state_dict0 = [{} for _ in range(61)]
91
+ layer_state_dict = [{} for _ in range(61)]
92
+
93
+ experts = [ [{} for _j in range(256)] for _i in range(61)]
94
+
95
+ def getF32PrintStr(ele):
96
+ v = int(ele.cpu().view(torch.uint32).item())
97
+ ex = str((v >> 23 & 0xFF) - 127)
98
+ r = '(1+' + str(v & 0x7FFFFF) + '/8388608)'
99
+ if v & 0x80000000:
100
+ vstr = '-' + r + '*2^' + ex
101
+ else:
102
+ vstr = r + '*2^' + ex
103
+ return vstr
104
+
105
+ def getBF16PrintStr(ele):
106
+ v = int(ele.cpu().view(torch.uint16).item())
107
+ ex = v >> 7 & 0xFF
108
+ r = '(1+' + str(v & 0x7F) + '/128)'
109
+ rraw = v & 0x7F
110
+
111
+ if v & 0x8000:
112
+ vstr = '-' + r + '*2^' + str(ex - 127)
113
+ else:
114
+ vstr = r + '*2^' + str(ex - 127)
115
+ return vstr
116
+
117
def getBF8PrintStr(ele):
    """Render a scalar FP8 (e4m3-style: 4-bit exponent, bias 7, 3-bit mantissa)
    tensor as an exact bit-level string ``(1+mantissa/8)*2^exponent``.

    BUG FIX: the original converted the exponent to a *string* and then
    compared it to the ints -7 and 8, so the extreme-exponent debug print
    could never fire. Keep the exponent as an int for the comparison.
    """
    v = int(ele.cpu().view(torch.uint8).item())
    exponent = (v >> 3 & 0xF) - 7  # 4-bit exponent field, bias 7
    r = '(1+' + str(v & 0x7) + '/8)'

    if v & 0x80:  # sign bit
        vstr = '-' + r + '*2^' + str(exponent)
    else:
        vstr = r + '*2^' + str(exponent)

    # Log values at the edges of the representable exponent range.
    if exponent == -7 or exponent == 8:
        print(vstr)
    return vstr
130
+
131
def mem(i):
    """Print a one-line CUDA memory report (allocated / reserved / peak), tagged with *i*."""
    mib = 1024 ** 2
    alloc = torch.cuda.memory_allocated() / mib
    reserved = torch.cuda.memory_reserved() / mib
    peak = torch.cuda.max_memory_allocated() / mib
    print(f"{i} allocated={alloc:.1f}MB, reserved={reserved:.1f}MB, max={peak:.1f}MB", flush=True)
136
+
137
def handle_expert_w(layer_id, expert_id, idx, param_weight, weight_name, scale, typ, shape, experts_save_path):
    """Convert one routed-expert projection (w1/w2/w3) to fixed-point and stage it.

    Dequantizes the fp8 weight with its block scale (looked up from the raw
    per-layer state in ``layer_state_dict0``), multiplies by ``2**scale``,
    rounds to int32, and stores the result (plus the scale) under
    ``experts[layer_id][expert_id]``. Once all three projections and their
    scales are present (6 entries), the expert is written out as
    ``<experts_save_path>/<expert_id>.safetensors`` and the slot is cleared.

    Args:
        layer_id: transformer layer index.
        expert_id: routed-expert index within the layer.
        idx: projection index (1, 2 or 3 -> saved as w{idx}).
        param_weight: raw fp8 weight tensor from the checkpoint.
        weight_name: original parameter name (ends in '.weight').
        scale: log2 fixed-point exponent for this expert/projection.
        typ / shape: tensor type string and shape, for logging only.
        experts_save_path: directory receiving the per-expert files.
    """
    global layer_state_dict0
    global experts

    # The fp8 block scales were staged under the matching '....scale' key.
    scale_name = weight_name.replace('weight', 'scale')
    param_scale = layer_state_dict0[layer_id][scale_name]

    weight = weight_dequant(param_weight.cuda(), param_scale.cuda())
    # scale = experts_w3_rescales[layer_id][expert_id]
    rescale = 2 ** scale
    # NOTE(review): param_int stays on GPU until save_file — confirm intended.
    param_int = (weight.to(torch.float32) * rescale).round().to(torch.int32)
    # layer_state_dict[layer_id][weight_name] = param_int.cpu()
    # layer_state_dict[layer_id][scale_name] = torch.tensor(scale, dtype=torch.int32)
    weight_name2 = f'w{idx}.weight'
    scale_name2 = f'w{idx}.scale'
    experts[layer_id][expert_id][weight_name2] = param_int
    experts[layer_id][expert_id][scale_name2] = torch.tensor(scale, dtype=torch.int32)

    if len(experts[layer_id][expert_id]) == 6:  # w1, w2, w3 plus their matching scales
        save_file(experts[layer_id][expert_id], os.path.join(experts_save_path, f"{expert_id}.safetensors"))
        experts[layer_id][expert_id] = {}

    print(f'layer {layer_id} expert {expert_id} w{idx} type: {typ}, shape: {shape}, weight_name: {weight_name}, scale_name: {scale_name}')
160
+
161
def saveTensor(fileName, t):
    """Write tensor *t* to *fileName* as raw C-order bytes (native byte order).

    BUG FIX: the original opened the file twice — first in text mode ("w",
    utf-8), which only truncated the file and shadowed the handle, then again
    in binary mode nested inside the first ``with``. A single binary open is
    sufficient and avoids the redundant truncate/close cycle.
    """
    t = t.detach()
    if t.device.type != "cpu":
        t = t.cpu()
    t = t.contiguous()
    with open(fileName, "wb") as f:
        f.write(t.numpy().tobytes(order="C"))
169
+
170
def main(hf_ckpt_path, save_path, n_experts, mp):
    """
    Converts and saves model checkpoint files into a specified format.

    Loads an fp8 HF checkpoint, dequantizes each weight, converts it to a
    fixed-point int32/int64 representation (per-tensor power-of-two rescale),
    and writes per-layer / per-expert safetensors files plus head/norm/embed
    files into *save_path*.

    Args:
        hf_ckpt_path (str): Path to the directory containing the input checkpoint files.
        save_path (str): Path to the directory where the converted checkpoint files will be saved.
        n_experts (int): Total number of experts in the model.
        mp (int): Model parallelism factor.

    Returns:
        None
    """
    torch.cuda.set_device(0)
    # Default dtype for torch computation; BF16 is used here.
    torch.set_default_dtype(torch.bfloat16)
    # Cap CPU-side torch at 8 threads to avoid thread-contention overhead.
    torch.set_num_threads(8)
    # Fixed seed so every process draws identical random numbers at init.
    torch.manual_seed(965)

    # n_local_experts = n_experts // mp
    # state_dicts = [{} for _ in range(mp)]

    head_state_dict = {}
    norm_state_dict = {}
    embed_state_dict = {}

    # Per-layer, per-expert log2 rescale exponents read from w1/w2/w3.txt
    # (one whitespace-separated row of ints per layer).
    experts_w1_rescales = []
    experts_w2_rescales = []
    experts_w3_rescales = []

    with open("w1.txt", "r", encoding="utf-8") as f1:
        for line in f1:
            layer_line = line.strip().split()
            int_list = [int(s) for s in layer_line]
            experts_w1_rescales.append(int_list)

    with open("w2.txt", "r", encoding="utf-8") as f2:
        for line in f2:
            layer_line = line.strip().split()
            int_list = [int(s) for s in layer_line]
            experts_w2_rescales.append(int_list)

    with open("w3.txt", "r", encoding="utf-8") as f3:
        for line in f3:
            layer_line = line.strip().split()
            int_list = [int(s) for s in layer_line]
            experts_w3_rescales.append(int_list)

    # Pass 1: stream every *.safetensors shard, rename parameters to the local
    # naming scheme via `mapping`, convert head/norm/embed immediately, and
    # stage per-layer tensors in layer_state_dict0 for pass 2.
    for file_path in tqdm(glob(os.path.join(hf_ckpt_path, "*.safetensors"))):
        with safe_open(file_path, framework="pt", device="cpu") as f:
            print('Opening ' + file_path, flush=True)
            for name in f.keys():
                # print('name 1: ', name)
                # Skip the extra MTP layer (index 61).
                if "model.layers.61" in name:
                    continue

                param: torch.Tensor = f.get_tensor(name)
                if name.startswith("model."):
                    name = name[len("model."):]
                name = name.replace("self_attn", "attn")
                name = name.replace("mlp", "ffn")
                name = name.replace("weight_scale_inv", "scale")
                name = name.replace("e_score_correction_bias", "bias")
                key = name.split(".")[-2]
                assert key in mapping, f"Key {key} not found in mapping"
                # print('key::: ' + key)
                new_key, dim = mapping[key]
                # print('dim::: ' + str(dim))
                name = name.replace(key, new_key)

                ns = name.split(".")
                comp = ns[0]
                if comp == 'head':
                    name2 = name[len('head.'):]
                    print('head: ' + name2)

                    # Head weights -> int64 fixed point with scale 2^43.
                    param_int = (param.to(torch.float32) * (2 ** 43)).round().to(torch.int64)
                    head_state_dict[name2] = param_int
                elif comp == 'norm':
                    name2 = name[len('norm.'):]
                    print('norm: ' + name2)

                    # Final norm -> int64 fixed point with scale 2^15.
                    param_int = (param.to(torch.float32) * (2 ** 15)).round().to(torch.int64)
                    norm_state_dict[name2] = param_int
                elif comp == 'embed':
                    name2 = name[len('embed.'):]
                    print('embed: ' + name2)

                    # Embedding -> int64 fixed point with scale 2^31.
                    param_int = (param.to(torch.float32) * (2 ** 31)).round().to(torch.int64)
                    embed_state_dict[name2] = param_int

                    # Also export embedding rows in fixed-size .bin chunks.
                    # NOTE(review): a trailing partial chunk (< EmbedsInOneFile
                    # rows) is silently dropped by the floor division — confirm.
                    os.makedirs(EmbedsZKDir, exist_ok=True)
                    fileCount = param_int.shape[0] // EmbedsInOneFile
                    for i in range(0, fileCount):
                        saveTensor(EmbedsZKDir + str(i) + '.bin', param_int[i * EmbedsInOneFile : (i+1) * EmbedsInOneFile].cpu())
                elif comp == 'layers':
                    layer_id = int(ns[1])
                    name2 = '.'.join(ns[2:])
                    layer_state_dict0[layer_id][name2] = param

    print('Finish loading state dict from disk! ++++++++++')

    # Pass 2: per layer, dequantize + fixed-point-convert each staged tensor
    # and write layer-<id>.safetensors (experts go to separate per-expert files
    # via handle_expert_w).
    # for layer_id, states in enumerate(layer_state_dict0):
    for layer_id in range(len(layer_state_dict0)):
        os.makedirs(f'{save_path}/experts-{layer_id}', exist_ok=True)

        states = layer_state_dict0[layer_id]

        for name, param in states.items():
            ns = name.split(".")
            typ = param.type()
            shape = param.shape

            if ns[0] == 'attn_norm':
                print(f'layer {layer_id} {name}, type: {typ}', flush=True)
                if ns[1] == 'weight':
                    # Attention norm scale 2^21.
                    param_int = (param.to(torch.float32) * (2 ** 21)).round().to(torch.int32)
                    layer_state_dict[layer_id][name] = param_int
            elif ns[0] == 'ffn_norm':
                print(f'layer {layer_id} {name}, type: {typ}', flush=True)
                if ns[1] == 'weight':
                    # FFN norm scale 2^23.
                    param_int2 = (param.to(torch.float32) * (2 ** 23)).round().to(torch.int32)
                    layer_state_dict[layer_id][name] = param_int2
            elif ns[0] == 'ffn':
                if len(ns) == 3:
                    # Dense-layer MLP (w1/w2/w3) and the MoE gate live here.
                    # '.scale' entries are consumed alongside the matching
                    # '.weight', so they are skipped when seen directly.
                    if ns[1] == 'w1' and ns[2] == 'scale':
                        continue
                    elif ns[1] == 'w1' and ns[2] == 'weight':
                        param_weight = param.cuda()
                        weight_name = name

                        scale_name = name.replace('weight', 'scale')
                        param_scale = states[scale_name]

                        weight = weight_dequant(param_weight, param_scale.cuda())
                        scale = w1_rescales[layer_id]
                        rescale = 2 ** scale
                        param_int = (weight.to(torch.float32) * rescale).round().to(torch.int32)
                        layer_state_dict[layer_id][weight_name] = param_int.cpu()
                        layer_state_dict[layer_id][scale_name] = torch.tensor(scale, dtype=torch.int32)

                        print(f'layer {layer_id} w1 weight, type: {typ}, shape: {shape}, weight_name: {weight_name}, scale_name: {name}', flush=True)
                    elif ns[1] == 'w2' and ns[2] == 'scale':
                        continue
                    elif ns[1] == 'w2' and ns[2] == 'weight':
                        param_weight = param.cuda()
                        weight_name = name

                        scale_name = name.replace('weight', 'scale')
                        param_scale = states[scale_name]

                        weight = weight_dequant(param_weight, param_scale.cuda())
                        scale = w2_rescales[layer_id]
                        rescale = 2 ** scale
                        param_int = (weight.to(torch.float32) * rescale).round().to(torch.int32)
                        layer_state_dict[layer_id][weight_name] = param_int.cpu()
                        layer_state_dict[layer_id][scale_name] = torch.tensor(scale, dtype=torch.int32)

                        print(f'layer {layer_id} w2 weight, type: {typ}, shape: {shape}, weight_name: {weight_name}, scale_name: {name}', flush=True)
                    elif ns[1] == 'w3' and ns[2] == 'scale':
                        continue
                    elif ns[1] == 'w3' and ns[2] == 'weight':
                        param_weight = param.cuda()
                        weight_name = name

                        scale_name = name.replace('weight', 'scale')
                        param_scale = states[scale_name]

                        weight = weight_dequant(param_weight, param_scale.cuda())
                        scale = w3_rescales[layer_id]
                        rescale = 2 ** scale
                        param_int = (weight.to(torch.float32) * rescale).round().to(torch.int32)
                        layer_state_dict[layer_id][weight_name] = param_int.cpu()
                        layer_state_dict[layer_id][scale_name] = torch.tensor(scale, dtype=torch.int32)

                        print(f'layer {layer_id} w3 weight, type: {typ}, shape: {shape}, weight_name: {weight_name}, scale_name: {name}', flush=True)

                    elif ns[1] == 'gate' and ns[2] == 'weight':
                        # MoE router weight (stored in bf16, not fp8): scale per layer.
                        gate_rescale = 2 ** gate_rescales[layer_id]
                        gate_int = (param.to(torch.float32) * gate_rescale).round().to(torch.int32)
                        layer_state_dict[layer_id][name] = gate_int.cpu()
                        rescale_name = name.replace('weight', 'scale')
                        layer_state_dict[layer_id][rescale_name] = torch.tensor(gate_rescales[layer_id], dtype=torch.int32)
                        print(f'layer {layer_id}: gate_weight_name: {name}, gate_scale_name: {rescale_name}')
                    elif ns[1] == 'gate' and ns[2] == 'bias':
                        # Router score-correction bias, scale 2^23.
                        bias_int = (param.to(torch.float32) * (2 ** 23)).round().to(torch.int32)
                        layer_state_dict[layer_id][name] = bias_int.cpu()
                        print(f'layer {layer_id} bias: {name}')
                    else:
                        layer_state_dict[layer_id][name] = param
                elif len(ns) == 4:
                    # Shared-expert projections: ffn.shared_experts.wX.weight/scale.
                    if ns[1] == 'shared_experts':
                        if (ns[2] == 'w1' or ns[2] == 'w2' or ns[2] == 'w3') and ns[3] == 'scale':
                            continue
                        elif ns[2] == 'w1' and ns[3] == 'weight':
                            param_weight = param.cuda()
                            weight_name = name

                            scale_name = name.replace('weight', 'scale')
                            param_scale = states[scale_name]

                            weight = weight_dequant(param_weight, param_scale.cuda())
                            scale = shared_w1_rescales[layer_id]
                            rescale = 2 ** scale
                            param_int = (weight.to(torch.float32) * rescale).round().to(torch.int32)
                            layer_state_dict[layer_id][weight_name] = param_int.cpu()
                            layer_state_dict[layer_id][scale_name] = torch.tensor(scale, dtype=torch.int32)
                            print(f'layer {layer_id} shared_expert w1 type: {typ}, shape: {shape}, weight_name: {weight_name}, scale_name: {scale_name}')
                        elif ns[2] == 'w2' and ns[3] == 'weight':
                            param_weight = param.cuda()
                            weight_name = name

                            scale_name = name.replace('weight', 'scale')
                            param_scale = states[scale_name]

                            weight = weight_dequant(param_weight, param_scale.cuda())
                            scale = shared_w2_rescales[layer_id]
                            rescale = 2 ** scale
                            param_int = (weight.to(torch.float32) * rescale).round().to(torch.int32)
                            layer_state_dict[layer_id][weight_name] = param_int.cpu()
                            layer_state_dict[layer_id][scale_name] = torch.tensor(scale, dtype=torch.int32)
                            print(f'layer {layer_id} shared_expert w2 type: {typ}, shape: {shape}, weight_name: {weight_name}, scale_name: {scale_name}')
                        elif ns[2] == 'w3' and ns[3] == 'weight':
                            param_weight = param.cuda()
                            weight_name = name

                            scale_name = name.replace('weight', 'scale')
                            param_scale = states[scale_name]

                            weight = weight_dequant(param_weight, param_scale.cuda())
                            scale = shared_w3_rescales[layer_id]
                            rescale = 2 ** scale
                            param_int = (weight.to(torch.float32) * rescale).round().to(torch.int32)
                            layer_state_dict[layer_id][weight_name] = param_int.cpu()
                            layer_state_dict[layer_id][scale_name] = torch.tensor(scale, dtype=torch.int32)
                            print(f'layer {layer_id} shared_expert w3 type: {typ}, shape: {shape}, weight_name: {weight_name}, scale_name: {scale_name}')
                        else:
                            layer_state_dict[layer_id][name] = param
                    else:
                        layer_state_dict[layer_id][name] = param
                elif len(ns) == 5:
                    # Routed experts: ffn.experts.<expert_id>.wX.weight/scale,
                    # converted and saved per expert by handle_expert_w().
                    if ns[1] == 'experts':
                        expert_id = int(ns[2])
                        if (ns[3] == 'w1' or ns[3] == 'w2' or ns[3] == 'w3') and ns[4] == 'scale':
                            continue
                        elif ns[3] == 'w1' and ns[4] == 'weight':
                            scale = experts_w1_rescales[layer_id][expert_id]
                            handle_expert_w(layer_id, expert_id, 1, param, name, scale, typ, shape, f'{save_path}/experts-{layer_id}')
                        elif ns[3] == 'w2' and ns[4] == 'weight':
                            scale = experts_w2_rescales[layer_id][expert_id]
                            handle_expert_w(layer_id, expert_id, 2, param, name, scale, typ, shape, f'{save_path}/experts-{layer_id}')
                        elif ns[3] == 'w3' and ns[4] == 'weight':
                            scale = experts_w3_rescales[layer_id][expert_id]
                            handle_expert_w(layer_id, expert_id, 3, param, name, scale, typ, shape, f'{save_path}/experts-{layer_id}')
                        else:
                            layer_state_dict[layer_id][name] = param
                    else:
                        layer_state_dict[layer_id][name] = param
            elif ns[0] == 'attn':
                if len(ns) == 3:
                    if ns[1] == 'wq_a' and ns[2] == 'scale':
                        continue
                    elif ns[1] == 'wq_a' and ns[2] == 'weight':
                        param_weight = param.cuda()
                        weight_name = name

                        scale_name = name.replace('weight', 'scale')
                        param_scale = states[scale_name]

                        weight = weight_dequant(param_weight, param_scale.cuda())

                        # wq_a scale 2^30 (no per-layer table; no scale entry saved).
                        weight_int = (weight.to(torch.float32) * (2 ** 30)).round().to(torch.int32)

                        layer_state_dict[layer_id][weight_name] = weight_int.cpu()

                        print(f'layer {layer_id} wq_a weight, type: {typ}, shape: {shape}', flush=True)
                    elif ns[1] == 'q_norm':
                        print(f'layer {layer_id} q_norm, type: {typ}, shape: {shape}', flush=True)

                        param_int3 = (param.to(torch.float32) * (2 ** 19)).round().to(torch.int32)
                        layer_state_dict[layer_id][name] = param_int3
                    elif ns[1] == 'kv_norm':
                        print(f'layer {layer_id} kv_norm, type: {typ}, shape: {shape}', flush=True)

                        param_int4 = (param.to(torch.float32) * (2 ** 23)).round().to(torch.int32)
                        layer_state_dict[layer_id][name] = param_int4
                    elif ns[1] == 'wq_b' and ns[2] == 'scale':
                        continue
                    elif ns[1] == 'wq_b' and ns[2] == 'weight':
                        param_weight = param.cuda()
                        weight_name = name

                        scale_name = name.replace('weight', 'scale')
                        param_scale = states[scale_name]

                        weight = weight_dequant(param_weight, param_scale.cuda())

                        weight_int = (weight.to(torch.float32) * (2 ** 30)).round().to(torch.int32)

                        # Split per-head (128 heads x 192 dims) into the nope
                        # part (first 128 dims) and the rope part (last 64).
                        weight_int = weight_int.view(128, 192, 1536)
                        wq_b1, wq_b2 = torch.split(weight_int, [128, 64], dim=-2)

                        print(f'layer {layer_id} wq_b1 weight, shape: {wq_b1.shape}, wq_b2 weight, shape: {wq_b2.shape}', flush=True)

                        wq_b1 = wq_b1.reshape(128 * 128, 1536)
                        wq_b2 = wq_b2.reshape(128 * 64, 1536)
                        wq_b1_name = weight_name.replace('wq_b', 'wq_b1')
                        wq_b2_name = weight_name.replace('wq_b', 'wq_b2')

                        # layer_state_dict[layer_id][weight_name] = weight_int.cpu()
                        layer_state_dict[layer_id][wq_b1_name] = wq_b1.cpu()
                        layer_state_dict[layer_id][wq_b2_name] = wq_b2.cpu()

                        print(f'layer {layer_id} wq_b weight, type: {typ}, shape: {shape}', flush=True)
                    elif ns[1] == 'wkv_a' and ns[2] == 'scale':
                        continue
                    elif ns[1] == 'wkv_a' and ns[2] == 'weight':
                        param_weight = param.cuda()
                        weight_name = name

                        scale_name = name.replace('weight', 'scale')
                        param_scale = states[scale_name]

                        weight = weight_dequant(param_weight, param_scale.cuda())

                        weight_int = (weight.to(torch.float32) * (2 ** 29)).round().to(torch.int32)

                        # layer_state_dict[layer_id][weight_name] = weight_int.cpu()

                        # Split 576 output rows into kv-lora (512) and rope (64) parts.
                        weight_int = weight_int.view(576, 7168)
                        wkv_a1, wkv_a2 = torch.split(weight_int, [512, 64], dim=-2)

                        print(f'layer {layer_id} wkv_a1 weight, shape: {wkv_a1.shape}, wkv_a2 weight, shape: {wkv_a2.shape}', flush=True)

                        wkv_a1_name = weight_name.replace('wkv_a', 'wkv_a1')
                        wkv_a2_name = weight_name.replace('wkv_a', 'wkv_a2')

                        # layer_state_dict[layer_id][weight_name] = weight_int.cpu()
                        layer_state_dict[layer_id][wkv_a1_name] = wkv_a1.cpu()
                        layer_state_dict[layer_id][wkv_a2_name] = wkv_a2.cpu()

                        print(f'layer {layer_id} wkv_a weight, type: {typ}, shape: {shape}', flush=True)
                    elif ns[1] == 'wkv_b' and ns[2] == 'scale':
                        continue
                    elif ns[1] == 'wkv_b' and ns[2] == 'weight':
                        param_weight = param.cuda()
                        weight_name = name

                        scale_name = name.replace('weight', 'scale')
                        param_scale = states[scale_name]

                        weight = weight_dequant(param_weight, param_scale.cuda())

                        # Per head (128 heads): first 128 rows -> K up-projection,
                        # last 128 rows -> V up-projection; each rescaled with its
                        # own per-layer exponent table.
                        wkv_b = weight.view(128, 256, 512)

                        wkv_b_1 = wkv_b[:, :128]
                        wkv_b_1 = wkv_b_1.reshape(128 * 128, 512)
                        scale1 = wkv_b_1_rescales[layer_id]
                        wkv_b_1_rescale = 2 ** scale1
                        wkv_b_1_int = torch.round(wkv_b_1.to(torch.float32) * wkv_b_1_rescale).to(torch.int32)

                        wkv_b_2 = wkv_b[:, -128:]
                        wkv_b_2 = wkv_b_2.reshape(128 * 128, 512)
                        scale2 = wkv_b_2_rescales[layer_id]
                        wkv_b_2_rescale = 2 ** scale2
                        wkv_b_2_int = torch.round(wkv_b_2.to(torch.float32) * wkv_b_2_rescale).to(torch.int32)

                        wkv_b_1_name = weight_name.replace("wkv_b", "wkv_b_1")
                        wkv_b_1_scale_name = scale_name.replace("wkv_b", "wkv_b_1")
                        layer_state_dict[layer_id][wkv_b_1_name] = wkv_b_1_int.cpu()
                        layer_state_dict[layer_id][wkv_b_1_scale_name] = torch.tensor(scale1, dtype=torch.int32)

                        wkv_b_2_name = weight_name.replace("wkv_b", "wkv_b_2")
                        wkv_b_2_scale_name = scale_name.replace("wkv_b", "wkv_b_2")
                        layer_state_dict[layer_id][wkv_b_2_name] = wkv_b_2_int.cpu()
                        layer_state_dict[layer_id][wkv_b_2_scale_name] = torch.tensor(scale2, dtype=torch.int32)

                        print(f'layer {layer_id} wkv_b, type: {typ}, shape: {shape}, wkv_b_1 weight: {wkv_b_1_name}, wkv_b_1 scale: {wkv_b_1_scale_name}, wkv_b_2 weight: {wkv_b_2_name}, wkv_b_2 scale: {wkv_b_2_scale_name}', flush=True)
                    elif ns[1] == 'wo' and ns[2] == 'scale':
                        continue
                    elif ns[1] == 'wo' and ns[2] == 'weight':
                        param_weight = param.cuda()
                        weight_name = name

                        scale_name = name.replace('weight', 'scale')
                        param_scale = states[scale_name]

                        weight = weight_dequant(param_weight, param_scale.cuda())

                        scale = wo_rescales[layer_id]
                        rescale = 2 ** scale

                        if layer_id != 58:
                            param_int = (weight.to(torch.float32) * rescale).round().to(torch.int32)
                        else:
                            # Layer 58 special case: the single largest-magnitude
                            # wo element is zeroed and marked with INT32_MIN as a
                            # sentinel. NOTE(review): presumably an outlier that
                            # would overflow the fixed-point range — confirm.
                            wo_abs = weight.abs().cpu()
                            maxpos = wo_abs.argmax()
                            row, col = divmod(maxpos.item(), weight.size(1))
                            print(f'maxpos: {maxpos}, {row} {col}', flush=True)

                            vstr = getBF16PrintStr(weight[row][col])
                            print(f'weight[{row}][{col}]: {vstr}', flush=True)
                            weight[row][col] = 0
                            param_int = (weight.to(torch.float32) * rescale).round().to(torch.int32)
                            param_int[row][col] = -(2 ** 31)

                        layer_state_dict[layer_id][weight_name] = param_int.cpu()
                        layer_state_dict[layer_id][scale_name] = torch.tensor(scale, dtype=torch.int32)

                        print(f'layer {layer_id} wo weight, type: {typ}, shape: {shape}, weight: {weight_name}, scale: {scale_name}', flush=True)
                    else:
                        layer_state_dict[layer_id][name] = param
                else:
                    layer_state_dict[layer_id][name] = param
            else:
                layer_state_dict[layer_id][name] = param

        save_file(layer_state_dict[layer_id], os.path.join(save_path, f"layer-{layer_id}.safetensors"))
        print(f'Finish saving layer {layer_id}', flush=True)
        # Drop both staged dicts to release memory before the next layer.
        layer_state_dict0[layer_id] = {}
        layer_state_dict[layer_id] = {}

    print('Finish opening')

    os.makedirs(save_path, exist_ok=True)

    print(layer_state_dict)
    print(experts)

    save_file(head_state_dict, os.path.join(save_path, f"head_int.safetensors"))
    save_file(norm_state_dict, os.path.join(save_path, f"norm_int.safetensors"))
    save_file(embed_state_dict, os.path.join(save_path, f"embed_int.safetensors"))
    # for i, st in enumerate(layer_state_dict):
    #     # print(f'{i} {st['attn_norm.weight']}', flush=True)
    #     # print(f'{i} {st['ffn_norm.weight']}', flush=True)
    #     save_file(st, os.path.join(save_path, f"layer-{i}.safetensors"))
    #     print(f'Finish saving layer {i}', flush=True)

    # for i in trange(mp):
    #     save_file(state_dicts[i], os.path.join(save_path, f"model{i}-mp{mp}.safetensors"))

    # print('Finish saving files')

    # Copy tokenizer files alongside the converted weights.
    for file_path in glob(os.path.join(hf_ckpt_path, "*token*")):
        new_file_path = os.path.join(save_path, os.path.basename(file_path))
        shutil.copyfile(file_path, new_file_path)
620
+
621
+
622
if __name__ == "__main__":
    # CLI entry point: convert an fp8 HF checkpoint into the fixed-point layout.
    parser = ArgumentParser()
    parser.add_argument("--hf-ckpt-path", type=str, required=True)
    parser.add_argument("--save-path", type=str, required=True)
    parser.add_argument("--n-experts", type=int, required=True)
    parser.add_argument("--model-parallel", type=int, required=True)
    args = parser.parse_args()
    assert args.n_experts % args.model_parallel == 0, "Number of experts must be divisible by model parallelism"
    main(args.hf_ckpt_path, args.save_path, args.n_experts, args.model_parallel)
inference/generate.py ADDED
@@ -0,0 +1,458 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import time
3
+ import json
4
+ from argparse import ArgumentParser
5
+ from typing import List
6
+ from torch import nn
7
+
8
+ import torch
9
+ import torch.distributed as dist
10
+ from transformers import AutoTokenizer
11
+ from safetensors.torch import load_file, load_model
12
+
13
+ from model import Transformer, ModelArgs, Block
14
+ from concurrent.futures import ThreadPoolExecutor
15
+ from kernel import softmax_q21, softmax_q19
16
+
17
# When True, enable the SNARK-related code paths elsewhere in this file.
snark = False

# Populated at startup (see load_model2 / main) before use.
model = None
# NOTE(review): `[tensor] * 61` repeats a reference to ONE tensor, so all 61
# list entries alias the same storage — writing layer i's cache writes every
# layer's. If these are meant to be independent per-layer KV / positional
# caches mutated in place, this is a bug (a list comprehension would allocate
# 61 distinct tensors); confirm whether the entries are reassigned before use.
kv_caches = [ torch.zeros(1, 4096 * 4, 512, dtype=torch.int64) ] * 61
pe_caches = [ torch.zeros(1, 4096 * 4, 64, dtype=torch.int64) ] * 61
state_dicts = [None] * 61
23
+
24
+ def getF32PrintStr(ele):
25
+ v = int(ele.cpu().view(torch.uint32).item())
26
+ ex = str((v >> 23 & 0xFF) - 127)
27
+ r = '(1+' + str(v & 0x7FFFFF) + '/8388608)'
28
+ if v & 0x80000000:
29
+ vstr = '-' + r + '*2^' + ex
30
+ else:
31
+ vstr = r + '*2^' + ex
32
+ return vstr
33
+
34
+ def getBF16PrintStr(ele):
35
+ v = int(ele.cpu().view(torch.uint16).item())
36
+ ex = v >> 7 & 0xFF
37
+ r = '(1+' + str(v & 0x7F) + '/128)'
38
+ rraw = v & 0x7F
39
+
40
+ if v & 0x8000:
41
+ vstr = '-' + r + '*2^' + str(ex - 127)
42
+ else:
43
+ vstr = r + '*2^' + str(ex - 127)
44
+ return vstr
45
+
46
def mem(i):
    """Print a one-line CUDA memory report (allocated / reserved / peak), tagged with *i*."""
    to_mib = lambda nbytes: nbytes / 1024 ** 2
    alloc = to_mib(torch.cuda.memory_allocated())
    reserved = to_mib(torch.cuda.memory_reserved())
    peak = to_mib(torch.cuda.max_memory_allocated())
    print(f"{i} allocated={alloc:.1f}MB, reserved={reserved:.1f}MB, max={peak:.1f}MB", flush=True)
51
+
52
def load_model2(ckpt_path):
    """Load the shared (non-layer) weights — embedding, final norm and output
    head — from ckpt_path into the global model."""
    global model

    with torch.device("cuda"):
        for part in ("embed", "norm", "head"):
            load_model(getattr(model, part), os.path.join(ckpt_path, f"{part}_int.safetensors"))
59
+
60
# NOTE: incoming logits are fixed-point with scale 2^21.
def sample(logits, temperature: float = 1.0):
    """
    Samples a token from the logits using temperature scaling.

    The pipeline is all integer (fixed-point) arithmetic: logits arrive with
    scale 2^21 and the custom softmax_q21 kernel emits probabilities at the
    same scale.  With the hard-coded sample_open = False this function is
    currently a plain greedy argmax and `temperature` is ignored.

    Args:
        logits (torch.Tensor): The logits tensor for token predictions.
        temperature (float, optional): Temperature for scaling logits. Defaults to 1.0.

    Returns:
        torch.Tensor: The sampled token.
    """
    # Original floating-point sampling path, kept for reference:
    # logits = logits.to(torch.float32) * (2 ** -15)
    # typ = logits.type()
    # print(f'sample logits type: {typ}')
    # logits = logits / max(temperature, 1e-5)
    # probs = torch.softmax(logits, dim=-1)
    # return probs.div_(torch.empty_like(probs).exponential_(1)).argmax(dim=-1)

    # Debug switch: True enables the full fixed-point Gumbel-style sampler.
    sample_open = False

    if sample_open:
        maxx = logits.abs().max()
        typ = logits.type()
        print(f'sample logits type: {typ}, shape: {logits.shape}, abs max: {maxx}')
        if temperature > 1e-5:
            # Integer temperature scaling (floor division).
            temp_int = int(temperature)
            # logits = (logits + temp_int // 2) // temp_int
            logits = logits // temp_int
            print(f'temp_int: {temp_int}', flush=True)
        else:
            logits = logits * (10 ** 5)
        # print(f'sample 22 logits type: {typ}, shape: {logits.shape}, logits: {logits}')
        # probs = torch.softmax(logits, dim=-1)

        # softmax_q21 expects a trailing singleton dim — see kernel module.
        logits = logits.unsqueeze(2)

        max0 = logits.abs().max()
        print(f'sample 2233 logits shape: {logits.shape}, abs max0: {max0}')

        # probs come back rescaled to 2^21.
        probs = torch.empty_like(logits, dtype=torch.int64, device='cuda')
        softmax_q21(logits.contiguous(), probs)

        probs = probs.squeeze(2)

        # print(f'sample 2233 probs shape: {probs.shape}')

        typ2 = probs.type()
        max1 = probs.abs().max()
        print(f'sample 33 probs type: {typ2}, shape: {probs.shape}, probs abs max: {max1}', flush=True)

        # Exponential noise for Gumbel-max style sampling (float, then
        # converted to fixed-point with a small offset to avoid div-by-zero).
        rand = torch.empty_like(probs, dtype=torch.float32, device='cuda').exponential_(1)
        rand_abs = rand.abs()
        rmin = getF32PrintStr(rand_abs.min())
        rmax = getF32PrintStr(rand_abs.max())
        print(f'sample 333 rand abs min: {rmin}, max: {rmax}', flush=True)

        # rand = (rand * (2 ** 21)).round().to(torch.int64) + (2 ** 15)
        rand = (rand * (2 ** 10)).round().to(torch.int64) + (2 ** 4)
        max2 = rand.abs().max()
        min2 = rand.abs().min()
        print(f'sample 55 rand abs min: {min2}, max: {max2}', flush=True)

        # probs rescaled to 2^21 (original), currently 2^10.
        # probs = (probs * (2 ** 21) + rand // 2) // rand
        probs = (probs * (2 ** 10)) // rand

        max3 = probs.abs().max()
        print(f'sample 66 probs abs max: {max3}', flush=True)

        res = probs.argmax(dim=-1)
        tid = res[0][0].item()
        tv = probs[0][0][tid]
        randv = rand[0][0][tid]
        # typ3 = res.type()
        print(f'sample 44 res: {res}, tid: {tid}, tv: {tv}, randv: {randv}')
    else:
        # Greedy path: pick the highest logit directly.
        probs = logits.unsqueeze(2)
        max3 = probs.abs().max()
        print(f'sample 66 probs abs max: {max3}', flush=True)

        res = probs.argmax(dim=-1)
    return res
144
+
145
def saveTensor(fileName, t):
    """Serialize tensor t to fileName as raw C-order bytes.

    Fix: the original opened the file TWICE — first in text mode ("w",
    utf-8), then again in binary mode inside that context.  The outer text
    handle was never written to and only truncated the file redundantly; a
    single binary open is both correct and sufficient.

    Args:
        fileName (str): Destination path (overwritten).
        t (torch.Tensor): Tensor to dump; moved to CPU and made contiguous
            first, so device tensors are accepted.
    """
    t = t.detach()
    if t.device.type != "cpu":
        t = t.cpu()
    t = t.contiguous()
    with open(fileName, "wb") as f:
        # .numpy() shares memory with the tensor; tobytes copies it out C-order.
        f.write(t.numpy().tobytes(order="C"))
154
+
155
# model: the model that produces the result tokens (DeepSeek architecture).
# prompt_tokens: the tokenized prompts; per the original notes, shape
#     (batch_size, total_len, 7168) after embedding — TODO confirm.
# max_new_tokens: cap on generated tokens; generation stops at this count or
#     at the end-of-sequence marker (eos_id), whichever comes first.
# eos_id: id of the <end▁of▁sentence> token; a sequence stops when it appears.
# temperature: sampling temperature — higher is more random; 0 means greedy.
# The prompt input is List[List[int]]: outer list is the batch, inner list one
# already-tokenized prompt (batch = 1 for a single "Who are you?" example).

@torch.inference_mode()
def generate(
    # model: Transformer,
    ckpt_path: str,
    args: ModelArgs,
    tokenizer: AutoTokenizer,
    prompt_tokens: List[List[int]],
    max_new_tokens: int,
    eos_id: int,
    temperature: float = 1.0
) -> List[List[int]]:
    """
    Generates new tokens based on the given prompt tokens using the specified model.

    Layers are swapped onto the GPU one at a time (rebuilt from the CPU-side
    state_dicts and the module-level kv/pe caches) to bound GPU memory.

    Args:
        ckpt_path (str): Checkpoint directory used to rebuild each Block.
        args (ModelArgs): Model hyper-parameters.
        tokenizer (AutoTokenizer): Tokenizer for debug decoding.
        prompt_tokens (List[List[int]]): A list of lists containing the prompt tokens for each sequence.
        max_new_tokens (int): The maximum number of new tokens to generate.
        eos_id (int): The end-of-sequence token ID.
        temperature (float, optional): The temperature value for sampling. Defaults to 1.0.

    Returns:
        List[List[int]]: A list of lists containing the generated tokens for each sequence.
    """

    # NOTE(review): `layers` is declared global but never assigned here.
    global model, layers
    global kv_caches, pe_caches

    prompt_lens = [len(t) for t in prompt_tokens]
    assert max(prompt_lens) <= args.max_seq_len, f"Prompt length exceeds model maximum sequence length (max_seq_len={args.max_seq_len})"
    total_len = min(args.max_seq_len, max_new_tokens + max(prompt_lens))
    # (batch_size, total_len) filled with -1; -1 marks "not yet filled".
    tokens = torch.full((len(prompt_tokens), total_len), -1, dtype=torch.long, device="cuda")
    # Copy each prompt into the front of its row; the tail is generation space.
    for i, t in enumerate(prompt_tokens):
        tokens[i, :len(t)] = torch.tensor(t, dtype=torch.long, device="cuda")

    beginstr = tokenizer.decode(tokens[0][0:prompt_lens[0]], skip_special_tokens=True)
    # torch.cuda.synchronize()
    print(' ++++++ token:', beginstr, flush=True)

    prev_pos = 0
    # Per-sequence completion flags; all start unfinished.
    finished = torch.tensor([False] * len(prompt_tokens), device="cuda")
    # Mask of positions already holding prompt tokens (token != -1) so the
    # autoregressive loop never overwrites user-provided tokens.
    prompt_mask = tokens != -1

    # Start generating at the end of the SHORTEST prompt so every sequence in
    # the batch gets a complete continuation.
    for cur_pos in range(min(prompt_lens), total_len):
        print(f'prev_pos: {prev_pos}, cur_pos: {cur_pos}, total_len: {total_len}', flush=True)
        t = tokenizer.decode(tokens[0][prev_pos:cur_pos], skip_special_tokens=True)
        print(str(cur_pos) + ' ---------- token list: ' + str(tokens[0][prev_pos:cur_pos].tolist()), flush=True)

        if snark:
            # Dump the step inputs for the zk-SNARK pipeline.
            os.makedirs(f'zkdata/pos_{prev_pos}', exist_ok=True)
            saveTensor(f'zkdata/pos_{prev_pos}/tokens.bin', tokens[0][prev_pos:cur_pos].cpu())

        # logits = model.forward(tokens[:, prev_pos:cur_pos], prev_pos)

        h, start_pos, seqlen = model.prep_inference(tokens[:, prev_pos:cur_pos], prev_pos)
        print('h 1 shape: ' + str(h.shape), flush=True)

        # Run the transformer one layer at a time, materializing each Block
        # on the GPU and tearing it down afterwards to cap GPU memory.
        for i in range(args.n_layers):
            print(f'begin layer {i} -----------------', flush=True)
            with torch.device("cuda"):

                with torch.no_grad():
                    if hasattr(model.layers[i], 'attn_norm'):
                        del model.layers[i].attn_norm.weight

                    model.layers[i] = Block(i, args, ckpt_path)
                    model.layers[i].load_state_dict(state_dicts[i], False)
                    # Restore this layer's caches (copied onto the GPU).
                    model.layers[i].attn.kv_cache = kv_caches[i].to('cuda')
                    model.layers[i].attn.pe_cache = pe_caches[i].to('cuda')

                h = model.layer_inference(i, h, start_pos, seqlen)

                # Stash the (updated) caches and drop the layer.
                kv_caches[i] = model.layers[i].attn.kv_cache
                pe_caches[i] = model.layers[i].attn.pe_cache
                model.layers[i] = nn.Module()

        # Debug path: greedy candidate from the last position's hidden state.
        tmph = model.norm(h)[0][:, -1]

        tmph_abs = tmph.abs()
        tmph_min = tmph_abs.min()
        tmph_max = tmph_abs.max()
        print(f'tmph_abs min: {tmph_min}, max: {tmph_max}', flush=True)

        tmplogits = model.head(tmph[None, :])

        tmp_next_token = tmplogits.argmax(dim=-1)
        tid = tmp_next_token[0][0].item()
        tmp_logit = tmplogits[0][0][tid]
        tmp_completion = tokenizer.decode([tmp_next_token[0][0]], skip_special_tokens=True)
        print(f'position {cur_pos} tid: {tid}, tmp_logit:{tmp_logit}, candidate: {tmp_completion}', flush=True)

        # logits are fixed-point with scale 2^21.
        logits = model.finish_inference(h)

        if temperature > 0:
            next_token = sample(logits, temperature)
        else:
            next_token = logits.argmax(dim=-1)
        # Keep prompt tokens untouched; only write into generated positions.
        next_token = torch.where(prompt_mask[:, cur_pos], tokens[:, cur_pos], next_token)
        # print('next_token shape: ' + str(next_token.shape))
        tokens[:, cur_pos] = next_token
        # A row finishes when a GENERATED (non-prompt) position emits eos_id;
        # once every row is finished, generation ends.
        finished |= torch.logical_and(~prompt_mask[:, cur_pos], next_token.view(-1) == eos_id)
        prev_pos = cur_pos

        completion = tokenizer.decode(tokens[0][0:cur_pos+1], skip_special_tokens=True)
        print(f'---------- Result: position {cur_pos}, token: {completion}', flush=True)

        if finished.all():
            break
    # Strip prompts and truncate at the first eos per sequence.
    completion_tokens = []
    for i, toks in enumerate(tokens.tolist()):
        toks = toks[prompt_lens[i]:prompt_lens[i]+max_new_tokens]
        if eos_id in toks:
            toks = toks[:toks.index(eos_id)]
        completion_tokens.append(toks)
    return completion_tokens
286
+
287
+
288
def main(
    ckpt_path: str,
    config: str,
    input_file: str = "",
    interactive: bool = True,
    max_new_tokens: int = 100,
    temperature: float = 1.0,
) -> None:
    """
    Main function to load the model and perform interactive or batch text generation.

    Args:
        ckpt_path (str): Path to the model checkpoint directory.
        config (str): Path to the model configuration file.
        input_file (str, optional): Path to a file containing input prompts. Defaults to "".
        interactive (bool, optional): Whether to run in interactive mode. Defaults to True.
        max_new_tokens (int, optional): Maximum number of new tokens to generate. Defaults to 100.
        temperature (float, optional): Temperature for sampling. Defaults to 1.0.
    """
    global model

    # WORLD_SIZE is the total number of processes (GPUs) in the job.
    world_size = int(os.getenv("WORLD_SIZE", "1"))
    # RANK is this process's global index in [0, world_size).
    rank = int(os.getenv("RANK", "0"))
    # LOCAL_RANK would be the per-node index (unused in single-GPU mode).
    # local_rank = int(os.getenv("LOCAL_RANK", "0"))
    print('WORLD_SIZE: ' + str(world_size) + ', rank: ' + str(rank))
    # Distributed (NCCL) initialization is currently disabled — this script
    # runs single-process; the broadcast logic below is kept for when it is
    # re-enabled.
    # if world_size > 1:
    #     dist.init_process_group("nccl")
    # global print
    # Silence non-master ranks so only rank 0 logs:
    # if rank != 0:
    #     print = lambda *_, **__: None
    # Pin this process to its GPU:
    # torch.cuda.set_device(local_rank)
    torch.cuda.set_device(0)
    # Default dtype for freshly-created tensors is bfloat16.
    torch.set_default_dtype(torch.bfloat16)
    # Cap CPU-side threading to avoid resource contention.
    torch.set_num_threads(8)
    # Fixed seed so all processes draw identical random numbers.
    torch.manual_seed(965)
    with open(config) as f:
        args = ModelArgs(**json.load(f))
    print(args)
    # Preload every layer's weights to CPU; generate() swaps them onto the
    # GPU one layer at a time to bound GPU memory usage.

    for i in range(args.n_layers):
        modelPath = os.path.join(ckpt_path, f"layer-{i}.safetensors")
        state_dicts[i] = load_file(modelPath, device="cpu")

    with torch.device("cuda"):
        model = Transformer(args)

    tokenizer = AutoTokenizer.from_pretrained(ckpt_path)
    load_model2(ckpt_path)

    # with torch.device("cuda"):
    #     freqs_cis_orig = precompute_freqs_cis(args)
    #     load_model2(ckpt_path)

    # tokenizer.encode maps text to token ids; tokenizer.decode maps back.
    # generate() keeps producing tokens until the stop token is emitted.
    # tokenizer.decode(generate(model, [tokenizer.encode("DeepSeek")], 200, -1, 1.)[0])
    # cmp1 = tokenizer.decode(generate(ckpt_path, args, tokenizer, [tokenizer.encode("DeepSeek")], 2, -1, 1.)[0])
    # print(' ---------- DeepSeek result: ' + str(cmp1), flush=True)
    # print('begin to load model: ' + f"model{rank}-mp{world_size}.safetensors")

    if rank == 0:
        # !!! This block can leak GPU memory.
        embed_abs = model.embed.weight.detach().cpu().abs()
        abs_min = torch.min(embed_abs)
        abs_max = torch.max(embed_abs)
        print('embed abs_min: ' + str(abs_min), flush=True)
        print('embed abs_max: ' + str(abs_max), flush=True)
    else:
        pass

    if interactive:
        messages = []
        while True:
            if world_size == 1:
                prompt = input(">>> ")
            # In multi-process mode only rank 0 reads stdin; it broadcasts the
            # prompt via dist.broadcast_object_list and the other ranks block
            # on the same call, so every process ends up with the same prompt.
            elif rank == 0:
                prompt = input(">>> ")
                objects = [prompt]
                dist.broadcast_object_list(objects, 0)
            else:
                objects = [None]
                dist.broadcast_object_list(objects, 0)
                prompt = objects[0]
            if prompt == "/exit":
                break
            elif prompt == "/clear":
                messages.clear()
                continue
            # Chat history accumulates as [{"role": ..., "content": ...}], e.g.
            # [{"role": "user", "content": "Hello, who are you?"}].
            messages.append({"role": "user", "content": prompt})
            # apply_chat_template renders the history into the model's chat
            # format; with its default tokenize=True it returns a List[int]
            # of input ids ready for generate().
            prompt_tokens = tokenizer.apply_chat_template(messages, add_generation_prompt=True)
            # prompt_tokens = tokenizer(prompt, add_special_tokens=True)
            # Generate, decode, print, and fold the reply back into history.
            # completion_tokens = generate(model, [prompt_tokens], max_new_tokens, tokenizer.eos_token_id, temperature)
            # with torch.no_grad():
            completion_tokens = generate(ckpt_path, args, tokenizer, [prompt_tokens], max_new_tokens, tokenizer.eos_token_id, temperature)
            # completion_tokens = generate(ckpt_path, args, tokenizer, [prompt_tokens['input_ids']], max_new_tokens, tokenizer.eos_token_id, temperature)
            completion = tokenizer.decode(completion_tokens[0], skip_special_tokens=True)
            print(completion)
            messages.append({"role": "assistant", "content": completion})
    else:
        # Batch mode: one prompt per line of input_file.
        with open(input_file) as f:
            prompts = [line.strip() for line in f.readlines()]
        assert len(prompts) <= args.max_batch_size, f"Number of prompts exceeds maximum batch size ({args.max_batch_size})"
        prompt_tokens = [tokenizer.apply_chat_template([{"role": "user", "content": prompt}], add_generation_prompt=True) for prompt in prompts]
        # completion_tokens = generate(model, prompt_tokens, max_new_tokens, tokenizer.eos_token_id, temperature)
        completion_tokens = generate(ckpt_path, args, tokenizer, prompt_tokens, max_new_tokens, tokenizer.eos_token_id, temperature)
        completions = tokenizer.batch_decode(completion_tokens, skip_special_tokens=True)
        for prompt, completion in zip(prompts, completions):
            print("Prompt:", prompt)
            print("Completion:", completion)
            print()

    if world_size > 1:
        dist.destroy_process_group()
432
+
433
+
434
if __name__ == "__main__":
    """
    Command-line interface for distributed text generation.

    Arguments:
        --ckpt-path (str): Path to the model checkpoint directory.
        --config (str): Path to the model hyper-parameter configuration file.
        --input-file (str, optional): File containing prompts for batch processing
            (one prompt per line).
        --interactive (bool, optional): Enable interactive chat ("Q&A") mode.
        --max-new-tokens (int, optional): Maximum number of new tokens to generate.
            Defaults to 200.
        --temperature (float, optional): Temperature for sampling. Defaults to 0.2.

    Raises:
        AssertionError: If neither input-file nor interactive mode is specified.
    """
    parser = ArgumentParser()
    parser.add_argument("--ckpt-path", type=str, required=True)
    parser.add_argument("--config", type=str, required=True)
    parser.add_argument("--input-file", type=str, default="")
    parser.add_argument("--interactive", action="store_true")
    parser.add_argument("--max-new-tokens", type=int, default=200)
    parser.add_argument("--temperature", type=float, default=0.2)
    args = parser.parse_args()
    assert args.input_file or args.interactive, "Either input-file or interactive mode must be specified"
    main(args.ckpt_path, args.config, args.input_file, args.interactive, args.max_new_tokens, args.temperature)
inference/int64_gemm.cu ADDED
@@ -0,0 +1,1030 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // int64_gemm.cu
2
+ #include <cuda_runtime.h>
3
+ #include <stdint.h>
4
+ #include <stdio.h>
5
+
6
// Batched fixed-point GEMM with quotient/remainder outputs:
//   sum     = sum_k (A[b,m,k] / a_rescale) * (B[n,k] / b_rescale)   (128-bit)
//   C[b,m,n] = sum >> c_rescale,  R[b,m,n] = low c_rescale bits of sum.
// B has no batch dimension and is broadcast; c_rescale is a SHIFT COUNT.
extern "C" __global__ void int64_32_bmm_broadcast_kernel(
    const int64_t* __restrict__ A, // (B, M, K)
    const int32_t* __restrict__ B, // (N, K)
    int64_t* __restrict__ C, // (B, M, N) quotient
    int64_t* __restrict__ R, // (B, M, N) remainder
    const int64_t a_rescale,
    const int64_t b_rescale,
    const int64_t c_rescale,
    int Bdim, int M, int K, int N)
{
    int b = blockIdx.z; // batch
    int row = blockIdx.y * blockDim.y + threadIdx.y; // M
    int col = blockIdx.x * blockDim.x + threadIdx.x; // N

    if (row < M && col < N) {
        __int128_t sum = 0;
        // Fix: build the remainder mask with a 128-bit 1.  The original
        // `(1 << c_rescale) - 1` shifts a 32-bit int, which is undefined
        // behaviour for c_rescale >= 31.
        __int128_t mask = (__int128_t(1) << c_rescale) - 1;
        // 64-bit linear offsets so b*M*K etc. cannot overflow 32-bit int
        // for large tensors.
        long long a_base = (long long)b * M * K + (long long)row * K;
        long long b_base = (long long)col * K;
        for (int k = 0; k < K; ++k) {
            int64_t a_val = A[a_base + k]; // A[b, row, k]
            int32_t b_val = B[b_base + k]; // B[col, k]
            sum += __int128_t(a_val / a_rescale) * __int128_t(b_val / b_rescale);
        }
        long long ind = (long long)b * M * N + (long long)row * N + col;
        C[ind] = int64_t(sum >> c_rescale); // C[b, row, col]
        R[ind] = int64_t(sum & mask);       // R[b, row, col]
    }
}
35
+
36
// Host launcher: one thread per (row, col) output element in 32x32 tiles,
// one grid z-slice per batch.
extern "C" void int64_32_bmm_broadcast_launcher(
    const int64_t* A, const int32_t* B, int64_t* C, int64_t* R,
    const int64_t a_rescale, const int64_t b_rescale, const int64_t c_rescale,
    int Bdim, int M, int K, int N)
{
    const dim3 block(32, 32);
    const unsigned gx = (N + block.x - 1) / block.x;
    const unsigned gy = (M + block.y - 1) / block.y;
    const dim3 grid(gx, gy, Bdim);

    int64_32_bmm_broadcast_kernel<<<grid, block>>>(A, B, C, R, a_rescale, b_rescale, c_rescale, Bdim, M, K, N);
}
48
+
49
// Same as int64_32_bmm_broadcast_kernel, but with int64 weights:
//   sum     = sum_k (A[b,m,k] / a_rescale) * (B[n,k] / b_rescale)   (128-bit)
//   C[b,m,n] = sum >> c_rescale,  R[b,m,n] = low c_rescale bits of sum.
extern "C" __global__ void int64_64_bmm_broadcast_kernel(
    const int64_t* __restrict__ A, // (B, M, K)
    const int64_t* __restrict__ B, // (N, K)
    int64_t* __restrict__ C, // (B, M, N) quotient
    int64_t* __restrict__ R, // (B, M, N) remainder
    const int64_t a_rescale,
    const int64_t b_rescale,
    const int64_t c_rescale,
    int Bdim, int M, int K, int N)
{
    int b = blockIdx.z; // batch
    int row = blockIdx.y * blockDim.y + threadIdx.y; // M
    int col = blockIdx.x * blockDim.x + threadIdx.x; // N

    if (row < M && col < N) {
        __int128_t sum = 0;
        // Fix: build the remainder mask with a 128-bit 1.  The original
        // `(1 << c_rescale) - 1` shifts a 32-bit int, which is undefined
        // behaviour for c_rescale >= 31.
        __int128_t mask = (__int128_t(1) << c_rescale) - 1;
        // 64-bit linear offsets so the flat indices cannot overflow int.
        long long a_base = (long long)b * M * K + (long long)row * K;
        long long b_base = (long long)col * K;
        for (int k = 0; k < K; ++k) {
            int64_t a_val = A[a_base + k]; // A[b, row, k]
            int64_t b_val = B[b_base + k]; // B[col, k]
            sum += __int128_t(a_val / a_rescale) * __int128_t(b_val / b_rescale);
        }
        long long ind = (long long)b * M * N + (long long)row * N + col;
        C[ind] = int64_t(sum >> c_rescale); // C[b, row, col]
        R[ind] = int64_t(sum & mask);       // R[b, row, col]
    }
}
78
+
79
// Host launcher: 32x32 thread tiles over (M, N), one grid z-slice per batch.
extern "C" void int64_64_bmm_broadcast_launcher(
    const int64_t* A, const int64_t* B, int64_t* C, int64_t* R,
    const int64_t a_rescale, const int64_t b_rescale, const int64_t c_rescale,
    int Bdim, int M, int K, int N)
{
    const dim3 block(32, 32);
    const unsigned gx = (N + block.x - 1) / block.x;
    const unsigned gy = (M + block.y - 1) / block.y;
    const dim3 grid(gx, gy, Bdim);

    int64_64_bmm_broadcast_kernel<<<grid, block>>>(A, B, C, R, a_rescale, b_rescale, c_rescale, Bdim, M, K, N);
}
91
+
92
// Convert a (rows, cols) bfloat16 matrix (raw uint16 bits) to int32
// fixed-point: output = value * 2^rescale, computed exactly from the
// sign / exponent / mantissa fields (no floating-point arithmetic).
extern "C" __global__ void bf16_to_int32_2d_kernel(const uint16_t* input, int32_t* output, int rows, int cols, int rescale)
{
    int row = blockIdx.y * blockDim.y + threadIdx.y;
    int col = blockIdx.x * blockDim.x + threadIdx.x;

    if (row < rows && col < cols) {
        int idx = row * cols + col;

        int v0 = input[idx];
        int ex0 = ((v0 >> 7) & 0xFF) - 127;   // unbiased exponent
        int r0 = v0 & 0x7F;                   // 7-bit mantissa

        // Exact zero (either sign) maps to 0.
        if (ex0 == -127 && r0 == 0) {
            output[idx] = 0;
            return;
        }
        // NOTE(review): subnormals (ex0 == -127, r0 != 0) still get the
        // implicit leading 1 below — assumes inputs are never subnormal;
        // confirm upstream.

        int ex2 = ex0 + rescale;
        int r2 = r0 + 128;                    // mantissa with implicit 1 (x128)
        uint32_t v = 0;
        // NOTE(review): `1 << ex2` is a 32-bit shift — undefined for
        // |ex2| >= 31; assumes rescale keeps ex2 in range.
        if(ex2 >= 0) {
            v = r2 * (1 << ex2);
        } else {
            v = r2 / (1 << -ex2);
        }

        if (v0 & 0x8000) {
            v = -v;                           // negate via unsigned wraparound
        }

        output[idx] = v;
    }
}
125
+
126
// Host launcher: one thread per matrix element, 32x32 tiles.
extern "C" void bf16_to_int32_2d(const uint16_t* input, int32_t* output, int rows, int cols, int rescale) {
    const dim3 tile(32, 32);
    const dim3 grid((cols + tile.x - 1) / tile.x,
                    (rows + tile.y - 1) / tile.y);

    bf16_to_int32_2d_kernel<<<grid, tile>>>(input, output, rows, cols, rescale);
}
133
+
134
// Convert bf16 wkv_b weights (raw uint16 bits) to int32 fixed-point with a
// FIXED scale of 2^25, clamping any value with unbiased exponent >= -1
// (|value| >= 0.5) to 0x7FFFFFFF.
// NOTE(review): the clamp ignores the sign bit — large NEGATIVE inputs also
// produce +0x7FFFFFFF; confirm that inputs are pre-bounded.
extern "C" __global__ void wkv_b_bf16_to_int32_kernel(const uint16_t* input, int32_t* output, int rows, int cols)
{
    int row = blockIdx.y * blockDim.y + threadIdx.y;
    int col = blockIdx.x * blockDim.x + threadIdx.x;

    if (row < rows && col < cols) {
        int idx = row * cols + col;

        int v0 = input[idx];
        int ex0 = ((v0 >> 7) & 0xFF) - 127;   // unbiased exponent
        int r0 = v0 & 0x7F;                   // 7-bit mantissa

        // Exact zero maps to 0.
        if (ex0 == -127 && r0 == 0) {
            output[idx] = 0;
            return;
        }

        // Clamp |value| >= 0.5 to INT32_MAX.
        if (ex0 >= -1) {
            output[idx] = 0x7FFFFFFF;
            return;
        }

        int ex2 = ex0 + 25;                   // fixed 2^25 rescale
        int r2 = r0 + 128;                    // mantissa with implicit 1 (x128)
        uint32_t v = 0;
        if(ex2 >= 0) {
            v = r2 * (1 << ex2);
        } else {
            v = r2 / (1 << -ex2);
        }

        if (v0 & 0x8000) {
            v = -v;                           // negate via unsigned wraparound
        }

        output[idx] = v;
    }
}
172
+
173
// Host launcher: one thread per matrix element, 32x32 tiles.
extern "C" void wkv_b_bf16_to_int32(const uint16_t* input, int32_t* output, int rows, int cols) {
    const dim3 tile(32, 32);
    const dim3 grid((cols + tile.x - 1) / tile.x,
                    (rows + tile.y - 1) / tile.y);

    wkv_b_bf16_to_int32_kernel<<<grid, tile>>>(input, output, rows, cols);
}
180
+
181
// Convert a (rows, cols) float32 matrix (raw uint32 bits) to int64
// fixed-point: output = value * 2^rescale, computed exactly from the
// sign / exponent / mantissa fields (no floating-point arithmetic).
extern "C" __global__ void float32_to_int64_2d_kernel(const uint32_t* input, int64_t* output, int rows, int cols, int rescale)
{
    int row = blockIdx.y * blockDim.y + threadIdx.y;
    int col = blockIdx.x * blockDim.x + threadIdx.x;

    if (row < rows && col < cols) {
        int idx = row * cols + col;

        uint32_t v0 = input[idx];
        int ex0 = ((v0 >> 23) & 0xFF) - 127;   // unbiased exponent
        int r0 = v0 & 0x7FFFFF;                // 23-bit mantissa

        // Exact zero (either sign) maps to 0.
        if (ex0 == -127 && r0 == 0) {
            output[idx] = 0;
            return;
        }

        int ex2 = ex0 + rescale;
        int64_t r2 = r0 + 8388608;             // mantissa with implicit 1 (x2^23)
        int64_t v = 0;
        // Fix: shift a 64-bit 1.  The original `1 << ex2` shifts a 32-bit
        // int, which is undefined behaviour for ex2 >= 31 even though the
        // result is stored in an int64 — values whose scaled exponent
        // exceeded 30 were silently corrupted.
        if (ex2 >= 0) {
            v = r2 * ((int64_t)1 << ex2);
        } else {
            v = r2 / ((int64_t)1 << -ex2);
        }

        if (v0 & 0x80000000) {
            v = -v;
        }

        output[idx] = v;
    }
}
214
+
215
// Host launcher: one thread per matrix element, 32x32 tiles.
extern "C" void float32_to_int64_2d(const uint32_t* input, int64_t* output, int rows, int cols, int rescale) {
    const dim3 tile(32, 32);
    const dim3 grid((cols + tile.x - 1) / tile.x,
                    (rows + tile.y - 1) / tile.y);

    float32_to_int64_2d_kernel<<<grid, tile>>>(input, output, rows, cols, rescale);
}
222
+
223
// Element-wise complex multiply in int64 fixed-point.  A holds interleaved
// (re, im) pairs of shape [batch, seq, head, dim]; B (effectively
// [1, seq, 1, dim]) is broadcast over batch and heads.  Each 64x64-bit
// product is rescaled by an effective `>> 42`, reassembled from __mul64hi
// (high 64 bits) and the low 64-bit product: low 42 bits of the high word
// shifted up 22, OR'd with bits 42..63 of the low word.
// NOTE(review): the 0x3FFFFFFFFFF mask drops the sign-extension bits of
// __mul64hi, so negative products rely on two's-complement wraparound of
// the final add/sub — confirm against the reference implementation.
extern "C" __global__ void complex_int64_mul_kernel(
    const int64_t* __restrict__ A,
    const int64_t* __restrict__ B,
    int64_t* __restrict__ C,
    // int64_t high_rescale, int64_t row_rescale,
    int batchSize, int seqLen, int headCount, int headDim)
{
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    int total = batchSize * seqLen * headCount * headDim;
    if (idx >= total) return;

    // Decompose the flat thread index into (batch, seq, head, dim).
    int i = idx;
    int dimId = i % headDim; i /= headDim;
    int headId = i % headCount; i /= headCount;
    int seqId = i % seqLen; i /= seqLen;
    int batchId = i;

    // A index (one complex pair per element)
    int a_idx = ((batchId * seqLen + seqId) * headCount + headId) * headDim + dimId;

    // B index (broadcast over batch and head)
    int b_idx = ((0 * seqLen + seqId) * 1 + 0) * headDim + dimId;

    int64_t a0 = A[2 * a_idx];
    int64_t a1 = A[2 * a_idx + 1];
    int64_t b0 = B[2 * b_idx];
    int64_t b1 = B[2 * b_idx + 1];

    // C[2 * a_idx] = (a0 * b0 - a1 * b1) / c_resacle;
    // C[2 * a_idx + 1] = (a0 * b1 + a1 * b0) / c_resacle;

    // C[2 * a_idx] = __mul64hi(a0, b0) * high_rescale + a0 * b0 / row_rescale) - (__mul64hi(a1, b1) * high_rescale + a1 * b1 / row_rescale);
    // C[2 * a_idx + 1] = (__mul64hi(a0, b1) * high_rescale + a0 * b1 / row_rescale) + (__mul64hi(a1, b0) * high_rescale + a1 * b0 / row_rescale);
    // Each term is (x*y) >> 42, assembled from the 128-bit product halves.
    int64_t a0b0 = ((__mul64hi(a0, b0) & 0x3FFFFFFFFFF) << 22) | (((a0 * b0) >> 42) & 0x3FFFFF);
    int64_t a1b1 = ((__mul64hi(a1, b1) & 0x3FFFFFFFFFF) << 22) | (((a1 * b1) >> 42) & 0x3FFFFF);
    int64_t a0b1 = ((__mul64hi(a0, b1) & 0x3FFFFFFFFFF) << 22) | (((a0 * b1) >> 42) & 0x3FFFFF);
    int64_t a1b0 = ((__mul64hi(a1, b0) & 0x3FFFFFFFFFF) << 22) | (((a1 * b0) >> 42) & 0x3FFFFF);

    // (a0 + a1 i) * (b0 + b1 i)
    C[2 * a_idx] = a0b0 - a1b1;
    C[2 * a_idx + 1] = a0b1 + a1b0;

    // Debug print, kept for reference:
    // if(idx == 32) {
    //     printf("%d %d %d, %d %d %d %d (%d %d %d %d): (%ld, %ld i) * (%ld, %ld i) = (%ld, %ld i)\n",
    //         idx, a_idx, b_idx,
    //         batchSize, seqLen, headCount, headDim,
    //         batchId, seqId, headId, dimId,
    //         a0, a1, b0, b1, C[2 * a_idx], C[2 * a_idx + 1]);
    // }
}
273
+
274
// Host launcher: one thread per complex element, 256-thread blocks.
extern "C" void complex_int64_mul(
    const int64_t* A, const int64_t* B, int64_t* C,
    // const int64_t high_rescale, const int64_t row_rescale,
    int batchSize, int seqLen, int headCount, int headDim)
{
    const int elem_count = batchSize * seqLen * headCount * headDim;
    const int per_block = 256;
    const int grid = (elem_count + per_block - 1) / per_block;

    complex_int64_mul_kernel<<<grid, per_block>>>(A, B, C,
        batchSize, seqLen, headCount, headDim);
}
287
+
288
+
289
// Fixed-point RMS-norm apply step with int32 weights:
//   C[s, d] = floor( A[s, d] * W[d] / rms[s] )
// The product is taken in 128-bit, and the truncating C division is adjusted
// to FLOOR division so negative values round consistently downward
// (assumes rms[s] > 0 — TODO confirm).
extern "C" __global__ void rms_norm_kernel_32(
    const int64_t* __restrict__ A,
    const int32_t* __restrict__ W,
    const int64_t* __restrict__ rms,
    int64_t* __restrict__ C,
    int seqLen, int Dim)
{
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    int total = seqLen * Dim;
    if (idx >= total) return;

    // Decompose the flat index into (sequence, dim).
    int dimId = idx % Dim;
    int seqId = idx / Dim;

    // A index
    int a_idx = seqId * Dim + dimId;

    // W index (broadcast over the sequence)
    int w_idx = dimId;

    int64_t a = A[a_idx];
    int32_t w = W[w_idx];
    int64_t r = rms[seqId];

    __int128 prod = ( __int128)a * ( __int128)w; // multiply in 128-bit: no overflow
    __int128 qq = prod / (__int128)r; // truncated quotient
    __int128 rr = prod % (__int128)r; // truncated remainder
    if(rr < 0) {
        // Convert truncation to floor: q -= 1, remainder += divisor.
        qq = qq - 1;
        rr = rr + r;
    }

    int64_t res = (int64_t)qq;

    C[a_idx] = res;
}
326
+
327
// Identical to rms_norm_kernel_32 but with int64 weights:
//   C[s, d] = floor( A[s, d] * W[d] / rms[s] )
// 128-bit intermediate product; truncating division adjusted to floor
// (assumes rms[s] > 0 — TODO confirm).
extern "C" __global__ void rms_norm_kernel_64(
    const int64_t* __restrict__ A,
    const int64_t* __restrict__ W,
    const int64_t* __restrict__ rms,
    int64_t* __restrict__ C,
    int seqLen, int Dim)
{
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    int total = seqLen * Dim;
    if (idx >= total) return;

    // Decompose the flat index into (sequence, dim).
    int dimId = idx % Dim;
    int seqId = idx / Dim;

    // A index
    int a_idx = seqId * Dim + dimId;

    // W index (broadcast over the sequence)
    int w_idx = dimId;

    int64_t a = A[a_idx];
    int64_t w = W[w_idx];
    int64_t r = rms[seqId];

    __int128 prod = ( __int128)a * ( __int128)w; // multiply in 128-bit: no overflow
    __int128 qq = prod / (__int128)r; // truncated quotient
    __int128 rr = prod % (__int128)r; // truncated remainder
    if(rr < 0) {
        // Convert truncation to floor: q -= 1, remainder += divisor.
        qq = qq - 1;
        rr = rr + r;
    }

    int64_t res = (int64_t)qq;

    C[a_idx] = res;
}
364
+
365
// Host launcher for the int32-weight RMS-norm kernel: one thread per
// (seq, dim) element, 256-thread blocks.
extern "C" void rms_norm_32(
    const int64_t* A, const int32_t* W, const int64_t* rms, int64_t* C,
    int seqLen, int Dim)
{
    const int elem_count = seqLen * Dim;
    const int tpb = 256;
    const int grid = (elem_count + tpb - 1) / tpb;

    rms_norm_kernel_32<<<grid, tpb>>>(A, W, rms, C, seqLen, Dim);
}
375
+
376
// Host launcher for the int64-weight RMS-norm kernel: one thread per
// (seq, dim) element, 256-thread blocks.
extern "C" void rms_norm_64(
    const int64_t* A, const int64_t* W, const int64_t* rms, int64_t* C,
    int seqLen, int Dim)
{
    const int elem_count = seqLen * Dim;
    const int tpb = 256;
    const int grid = (elem_count + tpb - 1) / tpb;

    rms_norm_kernel_64<<<grid, tpb>>>(A, W, rms, C, seqLen, Dim);
}
386
+
387
+ extern "C" __global__ void einsum_bshd_hdc_bshc_kernel(
388
+ const int64_t* q_nope, // [B, S, H, D]
389
+ const int32_t* wkv_b_1, // [H, D, C]
390
+ int64_t* out, // [B, S, H, C]
391
+ int64_t rescale,
392
+ int B, int S, int H, int D, int C)
393
+ {
394
+ int b = blockIdx.x; // batch
395
+ int s = blockIdx.y; // sequence
396
+ int h = blockIdx.z; // head
397
+ int c = threadIdx.x; // output channel
398
+
399
+ if (c >= C) return;
400
+
401
+ __int128_t sum = rescale / 2;
402
+ int q_base = ((b * S + s) * H + h) * D;
403
+ int w_base = h * D * C + c;
404
+ for (int d = 0; d < D; d++) {
405
+ // int w_idx = (h * D + d) * C + c;
406
+ sum += __int128_t(q_nope[q_base + d]) * __int128_t(wkv_b_1[w_base + d * C]);
407
+ }
408
+
409
+ // sum /= rescale;
410
+ int64_t sum2 = int64_t(sum >> rescale);
411
+
412
+ int out_idx = ((b * S + s) * H + h) * C + c;
413
+ out[out_idx] = sum2;
414
+ }
415
+
416
+ extern "C" void einsum_bshd_hdc_bshc(const int64_t* q_nope, const int32_t* wkv_b_1, int64_t* out,
417
+ int64_t rescale, int B, int S, int H, int D, int C) {
418
+
419
+ dim3 grid(B, S, H);
420
+ dim3 block(C);
421
+
422
+ einsum_bshd_hdc_bshc_kernel<<<grid, block>>>(
423
+ q_nope, wkv_b_1, out, rescale,
424
+ B, S, H, D, C);
425
+ }
426
+
427
+ extern "C" __global__ void einsum_bshc_btc_bsht_kernel(
428
+ const int64_t* __restrict__ A, // [B, S, H, C]
429
+ const int64_t* __restrict__ B, // [B, T, C]
430
+ int64_t* __restrict__ C, // [B, S, H, T]
431
+ int64_t rescale,
432
+ int Bsz, int S, int H, int T, int Cdim)
433
+ {
434
+ int b = blockIdx.x;
435
+ int s = blockIdx.y;
436
+ int h = blockIdx.z;
437
+ int t = threadIdx.x;
438
+
439
+ if (t >= T) return;
440
+
441
+ // 计算 A[b,s,h,:] 和 B[b,t,:] 的内积
442
+ __int128_t sum = rescale / 2;
443
+
444
+ int A_base = ((b * S + s) * H + h) * Cdim;
445
+ int B_base = (b * T + t) * Cdim;
446
+ for (int c = 0; c < Cdim; c++) {
447
+ // int idxB = (b * T + t) * Cdim + c;
448
+ sum += __int128_t(A[A_base + c]) * __int128_t(B[B_base + c]);
449
+ }
450
+
451
+ // sum /= rescale;
452
+ int64_t sum2 = int64_t(sum >> rescale);
453
+
454
+ int idxC = ((b * S + s) * H + h) * T + t;
455
+ C[idxC] = sum2;
456
+ }
457
+
458
+ extern "C" void einsum_bshc_btc_bsht(const int64_t* A, const int64_t* B, int64_t* C,
459
+ int64_t rescale, int Bsz, int S, int H, int T, int Cdim)
460
+ {
461
+ dim3 grid(Bsz, S, H);
462
+ dim3 block(T);
463
+
464
+ einsum_bshc_btc_bsht_kernel<<<grid, block>>>(
465
+ A, B, C, rescale,
466
+ Bsz, S, H, T, Cdim);
467
+ }
468
+
469
+ extern "C" __global__ void einsum_bsht_btc_bshc_kernel(
470
+ const int64_t* __restrict__ A,
471
+ const int64_t* __restrict__ B,
472
+ int64_t* __restrict__ C,
473
+ int64_t rescale,
474
+ int Bsz, int S, int H, int T, int Cdim)
475
+ {
476
+ int b = blockIdx.x;
477
+ int s = blockIdx.y;
478
+ int h = blockIdx.z;
479
+ int c = threadIdx.x;
480
+
481
+ if (c >= Cdim) return;
482
+
483
+ __int128_t sum = rescale / 2;
484
+
485
+ int A_base = ((b * S + s) * H + h) * T;
486
+ int B_base = b * T * Cdim + c;
487
+ for (int t = 0; t < T; ++t) {
488
+ // int idxB = (b * T + t) * Cdim + c;
489
+ sum += __int128_t(A[A_base + t]) * __int128_t(B[B_base + t * Cdim]);
490
+ }
491
+
492
+ // sum /= rescale;
493
+ int64_t sum2 = int64_t(sum >> rescale);
494
+
495
+ const int idxC = ((b * S + s) * H + h) * Cdim + c;
496
+ C[idxC] = sum2;
497
+ }
498
+
499
+ extern "C" void einsum_bsht_btc_bshc(
500
+ const int64_t* A, const int64_t* B, int64_t* C,
501
+ int64_t rescale, int Bsz, int S, int H, int T, int Cdim)
502
+ {
503
+ dim3 grid(Bsz, S, H);
504
+ dim3 block(Cdim);
505
+
506
+ einsum_bsht_btc_bshc_kernel<<<grid, block>>>(
507
+ A, B, C, rescale,
508
+ Bsz, S, H, T, Cdim);
509
+ }
510
+
511
+ extern "C" __global__ void einsum_bshc_hdc_bshd_kernel(
512
+ const int64_t* __restrict__ A,
513
+ const int32_t* __restrict__ B,
514
+ int64_t* __restrict__ C,
515
+ int64_t rescale,
516
+ int Bsz, int S, int H, int D, int Cdim)
517
+ {
518
+ int b = blockIdx.x;
519
+ int s = blockIdx.y;
520
+ int h = blockIdx.z;
521
+ int d = threadIdx.x;
522
+
523
+ if (d >= D) return;
524
+
525
+ __int128_t sum = 0;
526
+ int A_base = ((b * S + s) * H + h) * Cdim;
527
+ int B_base = (h * D + d) * Cdim;
528
+ for (int c = 0; c < Cdim; ++c) {
529
+ sum += __int128_t(A[A_base + c]) * __int128_t(B[B_base + c]);
530
+ }
531
+
532
+ // sum = (sum + rescale / 2) / rescale;
533
+ int64_t sum2 = int64_t(sum >> rescale);
534
+
535
+ const int idxC = ((b * S + s) * H + h) * D + d;
536
+ C[idxC] = sum2;
537
+ }
538
+
539
+ extern "C" void einsum_bshc_hdc_bshd(const int64_t* A, const int32_t* B, int64_t* C,
540
+ int64_t rescale, int Bsz, int S, int H, int D, int Cdim)
541
+ {
542
+ dim3 grid(Bsz, S, H);
543
+ dim3 block(D);
544
+
545
+ einsum_bshc_hdc_bshd_kernel<<<grid, block>>>(
546
+ A, B, C, rescale,
547
+ Bsz, S, H, D, Cdim
548
+ );
549
+ }
550
+
551
+ // static const int64_t LOG2E_Q32 = 6196328019ULL; // log2(e)*2^32
552
// Fixed-point constants for the softmax kernels below.
static const int64_t LOG2E_Q21 = 3025551;  // round(log2(e) * 2^21), Q21 fixed point
static const int64_t LOG2E_Q19 = 756388;   // round(log2(e) * 2^19), Q19 fixed point
// Number of index bits of the 2^x fraction lookup tables (table size 2^LOG_TABLE_SIZE).
// static const int LOG_TABLE_SIZE = 10;
static const int LOG_TABLE_SIZE = 8;
556
+ // static uint64_t EXP2_FRAC_LUT[256] = { /* 预生成:round(2^(i/256)*2^32) */ };
557
+ // static int64_t EXP2_FRAC_LUT[256] = { /* 预生成:round(2^(i/256)*2^32) */ };
558
+ // EXP2_FRAC_LUT = torch.zeros([256, ], dtype=torch.int64, device="cuda")
559
+
560
+ // extern "C" void softmax_q21_to_probs(const int64_t* R, int n, int64_t* P_q21) {
561
+ // int32_t Rmax = R[0];
562
+ // for (int i = 1; i < n; ++i) if (R[i] > Rmax) Rmax = R[i];
563
+
564
+ // // printf("Rmax: %d\n", Rmax);
565
+
566
+ // int64_t sumW = 0;
567
+ // static thread_local int64_t Wbuf[4096]; // 或动态分配
568
+ // for (int i = 0; i < n; ++i) {
569
+ // int64_t d = R[i] - Rmax; // Δ_i (<=0)
570
+
571
+ // // 剪裁:小于 -16 的差值近似为 0
572
+ // if (d < -(16 << 21))
573
+ // {
574
+ // Wbuf[i] = 0;
575
+ // continue;
576
+ // }
577
+
578
+ // // y = d * log2(e) / 2^21 (Q32)
579
+ // int64_t y = (d * LOG2E_Q21) >> 21;
580
+
581
+ // // printf("y: %ld\n", y);
582
+
583
+ // int64_t k = (-y) >> 21; // 整数部分 取正数(k > 0)
584
+ // int64_t f = (-y) & 0x1FFFFF; // 小数部分 取正数(Q21, f > 0)
585
+
586
+ // // printf("k: %ld, f: %ld\n", k, f);
587
+
588
+ // int64_t t = EXP2_FRAC_LUT[ f >> 13 ]; // 2^(frac(y)) in Q21, 13 = 21 - 8, 取小数部分 转换成整数之后的 高8位
589
+ // // int64_t t = 0;
590
+
591
+ // int64_t wi = (k >= 32) ? 0u : (t >> k); // 2^(-k) * t, 右移
592
+ // Wbuf[i] = wi;
593
+ // sumW += wi;
594
+ // }
595
+
596
+ // // 归一化到 Q21 概率
597
+ // for (int i = 0; i < n; ++i) {
598
+ // int64_t num = Wbuf[i] << 21; // 提升精度
599
+ // P_q21[i] = sumW ? (num / sumW) : 0;
600
+ // }
601
+ // }
602
+
603
+ // start of softmax_q21 ------------------------
604
+
605
+ extern "C" __global__ void softmax_kernel_q21(
606
+ int64_t* R, // [B, S, H, T]
607
+ int64_t* C, // [B, S, H, T]
608
+ int64_t LOG2E_Q21, int64_t* EXP2_FRAC_LUT,
609
+ int Bsz, int S, int H, int T)
610
+ {
611
+ int b = blockIdx.x;
612
+ int s = blockIdx.y;
613
+ int h = threadIdx.x;
614
+
615
+ if (h >= H) return;
616
+
617
+ int idxbase = ((b * S + s) * H + h) * T;
618
+
619
+ int64_t Rmax = R[idxbase];
620
+ for (int i = 1; i < T; ++i) if (R[idxbase + i] > Rmax) Rmax = R[idxbase + i];
621
+
622
+ int64_t sumW = 0;
623
+ for (int i = 0; i < T; i++) {
624
+ int64_t d = R[idxbase + i] - Rmax; // Δ_i (d <= 0)
625
+
626
+ // 剪裁:小于 -64 的差值近似为 0
627
+ if (d < -(64 << 21))
628
+ {
629
+ R[idxbase + i] = 0;
630
+ continue;
631
+ }
632
+
633
+ // y = d * log2(e) / 2^21 (Q21, y <= 0)
634
+ int64_t y = (d * LOG2E_Q21 + (1 << 20)) >> 21;
635
+
636
+ int64_t k = (-y) >> 21; // 整数部分 取正数(k > 0)
637
+ int64_t f = (-y) & 0x1FFFFF; // 小数部分 取正数(Q21)
638
+
639
+ // printf("k: %ld, f: %ld\n", k, f);
640
+
641
+ int64_t t = EXP2_FRAC_LUT[ f >> (21 - LOG_TABLE_SIZE) ]; // 2^(frac(y)) in Q21, 取小数部分 转换成整数之后的高 LOG_TABLE_SIZE 位
642
+ // int64_t t = 0;
643
+
644
+ int64_t wi = (k >= 64) ? 0 : (t >> k); // 2^(-k) * t, 右移
645
+
646
+ // if(b == 0 && s == 3 && h == 127) {
647
+ // printf("i: %d, d: %ld, y: %ld, k: %ld, f: %ld, t: %ld, wi: %ld\n", i, d, y, k, f, t, wi);
648
+ // }
649
+
650
+ R[idxbase + i] = wi;
651
+ sumW += wi;
652
+ }
653
+
654
+ // 归一化到 Q21 概率
655
+ for (int i = 0; i < T; ++i) {
656
+ int64_t num = R[idxbase + i] << 21; // 提升精度
657
+ C[idxbase + i] = sumW ? ((num + sumW / 2) / sumW) : 0;
658
+ }
659
+ }
660
+
661
+ extern "C" void softmax_q21(int64_t* R, int64_t* C, int64_t* EXP2_FRAC_LUT,
662
+ int Bsz, int S, int H, int T)
663
+ {
664
+ dim3 grid(Bsz, S);
665
+ dim3 block(H);
666
+
667
+ softmax_kernel_q21<<<grid, block>>>(
668
+ R, C, LOG2E_Q21, EXP2_FRAC_LUT,
669
+ Bsz, S, H, T);
670
+ }
671
+
672
+ extern "C" void softmax_init_q21(int64_t* EXP2_FRAC_LUT)
673
+ {
674
+ // printf("inited!\n");
675
+ int x2_21 = 1 << 21;
676
+ for(int i = 0; i < (1 << LOG_TABLE_SIZE); i++) {
677
+ // EXP2_FRAC_LUT[i] = uint64_t(std::pow(2, i / 256.0) * 4294967296);
678
+ EXP2_FRAC_LUT[i] = int64_t(std::pow(2, i * (-1.0f) / (1 << LOG_TABLE_SIZE)) * x2_21);
679
+ }
680
+ }
681
+ // -- end of softmax_q21 -----------------------------
682
+
683
+ // start of softmax_q19 ------------------------
684
+ extern "C" __global__ void softmax_kernel_q19(
685
+ int64_t* R, // [B, S, H, T]
686
+ int64_t* C, // [B, S, H, T]
687
+ int64_t LOG2E_Q19, int64_t* EXP2_FRAC_LUT,
688
+ int Bsz, int S, int H, int T)
689
+ {
690
+ int b = blockIdx.x;
691
+ int s = blockIdx.y;
692
+ int h = threadIdx.x;
693
+
694
+ if (h >= H) return;
695
+
696
+ int idxbase = ((b * S + s) * H + h) * T;
697
+
698
+ int64_t Rmax = R[idxbase];
699
+ for (int i = 1; i < T; ++i) if (R[idxbase + i] > Rmax) Rmax = R[idxbase + i];
700
+
701
+ int64_t sumW = 0;
702
+ for (int i = 0; i < T; i++) {
703
+ u_int64_t d = Rmax - R[idxbase + i]; // Δ_i (d >= 0)
704
+
705
+ // 剪裁:小于 -64 的差值近似为 0, d > (64 << 19)的条件 比 (k >= 64) 要宽松
706
+ if (d > (64 << 19))
707
+ {
708
+ R[idxbase + i] = 0;
709
+ continue;
710
+ }
711
+
712
+ // y = d * log2(e) / 2^19 (Q19, y >= 0)
713
+ int64_t y = int64_t((__int128_t(d) * __int128_t(LOG2E_Q19)) >> 19);
714
+
715
+ int64_t k = y >> 19; // 整数部分 取正数(k > 0)
716
+ int64_t f = y & 0x7FFFF; // 小数部分 取正数(Q19)
717
+
718
+ // printf("k: %ld, f: %ld\n", k, f);
719
+
720
+ int64_t t = EXP2_FRAC_LUT[ f >> (19 - LOG_TABLE_SIZE) ]; // 2^(frac(y)) in Q19, 取小数部分 转换成整数之后的高 LOG_TABLE_SIZE 位
721
+ // int64_t t = 0;
722
+
723
+ int64_t wi = (k >= 64) ? 0 : (t >> k); // 2^(-k) * t, 右移
724
+
725
+ // if(b == 0 && s == 2 && h == 2) {
726
+ // printf("i: %d, d: %ld, y: %ld, k: %ld, f: %ld, t: %ld, wi: %ld\n", i, d, y, k, f, t, wi);
727
+ // }
728
+
729
+ R[idxbase + i] = wi;
730
+ sumW += wi;
731
+ }
732
+
733
+ // if(b == 0 && s == 9 && h == 7) {
734
+ // printf("sumW: %ld\n", sumW);
735
+ // }
736
+
737
+ // 归一化到 Q19 概率
738
+ for (int i = 0; i < T; ++i) {
739
+ int64_t num = R[idxbase + i] << 19; // 提升精度
740
+ C[idxbase + i] = sumW ? (num / sumW) : 0;
741
+
742
+ // if(b == 0 && s == 9 && h == 7) {
743
+ // printf("i: %d, r: %ld, num: %ld, c: %ld\n", i, R[idxbase + i], num, C[idxbase + i]);
744
+ // }
745
+ }
746
+ }
747
+
748
+ extern "C" void softmax_q19(int64_t* R, int64_t* C, int64_t* EXP2_FRAC_LUT,
749
+ int Bsz, int S, int H, int T)
750
+ {
751
+ dim3 grid(Bsz, S);
752
+ dim3 block(H);
753
+
754
+ softmax_kernel_q19<<<grid, block>>>(
755
+ R, C, LOG2E_Q19, EXP2_FRAC_LUT,
756
+ Bsz, S, H, T);
757
+ }
758
+
759
+ extern "C" void softmax_init_q19(int64_t* EXP2_FRAC_LUT)
760
+ {
761
+ // printf("inited!\n");
762
+ int x2_19 = 1 << 19;
763
+ for(int i = 0; i < (1 << LOG_TABLE_SIZE); i++) {
764
+ // EXP2_FRAC_LUT[i] = uint64_t(std::pow(2, i / 256.0) * 4294967296);
765
+ EXP2_FRAC_LUT[i] = int64_t(std::pow(2, i * (-1.0f) / (1 << LOG_TABLE_SIZE)) * x2_19);
766
+ }
767
+ }
768
+ // -- end of softmax_q19 -----------------------------
769
+
770
+ // -- start of silu_q25 -----------------------------
771
// Fixed-point constants for the Q25 SiLU/sigmoid kernels below.
static const int64_t LOG2E_Q25 = 48408813;        // round(log2(e) * 2^25)
static const int64_t exp2_25 = 33554432;          // 2^25 == 1.0 in Q25
static const int64_t exp2_50 = 1125899906842624;  // 2^50, numerator for Q25 reciprocal
774
+
775
+ extern "C" __global__ void silu_kernel_q25(
776
+ int64_t* R, // [B, S, Dim]
777
+ int64_t* C, // [B, S, Dim]
778
+ int64_t LOG2E_Q25, int64_t* EXP2_FRAC_LUT_Q25,
779
+ int Bsz, int S, int Dim)
780
+ {
781
+ int b = blockIdx.x;
782
+ int s = blockIdx.y;
783
+ int d = blockIdx.z * blockDim.x + threadIdx.x;
784
+
785
+ if (d >= Dim) return;
786
+
787
+ int idx = (b * S + s) * Dim + d;
788
+ int64_t r = R[idx];
789
+
790
+ // 饱和区裁剪(可调阈值 64)
791
+ const int64_t LIM = (int64_t)64 << 25;
792
+ if (r >= LIM) // σ≈1 -> SiLU ~= x
793
+ {
794
+ C[idx] = r;
795
+ return;
796
+ }
797
+ if (r <= -LIM) // σ≈0 -> SiLU ~= 0
798
+ {
799
+ C[idx] = 0;
800
+ return;
801
+ }
802
+
803
+ // y = - x * log2(e) / 2^25 (Q25)
804
+ int64_t y = -int64_t((__int128_t(r) * __int128_t(LOG2E_Q25)) >> 25);
805
+
806
+ // u ≈ 2^y = e^{-x} = 2^k * 2^f (Q25)
807
+ int64_t u = 0;
808
+ int64_t k = y >> 25; // 整数部分
809
+ int64_t f = y & 0x1FFFFFF; // 小数部分 (Q25)
810
+ int64_t t = EXP2_FRAC_LUT_Q25[f >> (25 - LOG_TABLE_SIZE)]; // 2^(frac) in Q25
811
+ if (k > -63)
812
+ u = (k < 0) ? (t >> (-k)) : (t << k); // 一般 k<=0
813
+
814
+ // σ = 1 / (1 + u) (Q25)
815
+ int64_t q = exp2_50 / (exp2_25 + u); // Q25
816
+
817
+ // SiLU = x * σ : (Q25 * Q25) >> 25 → Q25
818
+ C[idx] = ((r * q) >> 25);
819
+ }
820
+
821
+ extern "C" void silu_q25(int64_t* R, int64_t* C, int64_t* EXP2_FRAC_LUT_25,
822
+ int Bsz, int S, int Dim)
823
+ {
824
+ dim3 grid(Bsz, S, (Dim + 255) / 256);
825
+ dim3 block(256, 1, 1);
826
+
827
+ silu_kernel_q25<<<grid, block>>>(
828
+ R, C, LOG2E_Q25, EXP2_FRAC_LUT_25,
829
+ Bsz, S, Dim);
830
+ }
831
+
832
+ extern "C" void silu_init_q25(int64_t* EXP2_FRAC_LUT)
833
+ {
834
+ int tableSize = 1 << LOG_TABLE_SIZE;
835
+ for(int i = 0; i < tableSize; i++) {
836
+ // EXP2_FRAC_LUT[i] = uint64_t(std::pow(2, i / 1024.0) * 2^25);
837
+ EXP2_FRAC_LUT[i] = int64_t(std::pow(2, i * 1.0f / tableSize) * exp2_25);
838
+ }
839
+ }
840
+
841
+ extern "C" __global__ void sigmoid_kernel_q25(
842
+ int64_t* R, // [B, S, Dim]
843
+ int64_t* C, // [B, S, Dim]
844
+ int64_t LOG2E_Q25, int64_t* EXP2_FRAC_LUT_Q25,
845
+ int Bsz, int S, int Dim)
846
+ {
847
+ int b = blockIdx.x;
848
+ int s = blockIdx.y;
849
+ int d = blockIdx.z * blockDim.x + threadIdx.x;
850
+
851
+ if (d >= Dim) return;
852
+
853
+ int idx = (b * S + s) * Dim + d;
854
+ int64_t r = R[idx];
855
+
856
+ // 饱和区裁剪(可调阈值 64)
857
+ const int64_t LIM = (int64_t)64 << 25;
858
+ if (r >= LIM) // σ≈1 -> SiLU ~= x
859
+ {
860
+ C[idx] = r;
861
+ return;
862
+ }
863
+ if (r <= -LIM) // σ≈0 -> SiLU ~= 0
864
+ {
865
+ C[idx] = 0;
866
+ return;
867
+ }
868
+
869
+ // y = - x * log2(e) / 2^25 (Q25)
870
+ int64_t y = -int64_t((__int128_t(r) * __int128_t(LOG2E_Q25)) >> 25);
871
+
872
+ // u ≈ 2^y = e^{-x} = 2^k * 2^f (Q25)
873
+ int64_t u = 0;
874
+ int64_t k = y >> 25; // 整数部分
875
+ int64_t f = y & 0x1FFFFFF; // 小数部分 (Q25)
876
+ int64_t t = EXP2_FRAC_LUT_Q25[f >> (25 - LOG_TABLE_SIZE)]; // 2^(frac) in Q25
877
+ if (k > -63)
878
+ u = (k < 0) ? (t >> (-k)) : (t << k); // 一般 k<=0
879
+
880
+ // σ = 1 / (1 + u) (Q25)
881
+ C[idx] = exp2_50 / (exp2_25 + u);
882
+ }
883
+
884
+ extern "C" void sigmoid_q25(int64_t* R, int64_t* C, int64_t* EXP2_FRAC_LUT_25,
885
+ int Bsz, int S, int Dim)
886
+ {
887
+ dim3 grid(Bsz, S, (Dim + 255) / 256);
888
+ dim3 block(256, 1, 1);
889
+
890
+ sigmoid_kernel_q25<<<grid, block>>>(
891
+ R, C, LOG2E_Q25, EXP2_FRAC_LUT_25,
892
+ Bsz, S, Dim);
893
+ }
894
+
895
+ // -- end of silu_q25 -----------------------------
896
+
897
+
898
+ // -- start of silu_q23 -----------------------------
899
// Fixed-point constants for the Q23 SiLU/sigmoid kernels below.
static const int64_t LOG2E_Q23 = 12102203;      // round(log2(e) * 2^23)
static const int64_t exp2_23 = 8388608;         // 2^23 == 1.0 in Q23
static const int64_t exp2_46 = 70368744177664;  // 2^46, numerator for Q23 reciprocal
902
+
903
+ extern "C" __global__ void silu_kernel_q23(
904
+ int64_t* R, // [B, S, Dim]
905
+ int64_t* C, // [B, S, Dim]
906
+ int64_t LOG2E_Q23, int64_t* EXP2_FRAC_LUT_Q23,
907
+ int Bsz, int S, int Dim)
908
+ {
909
+ int b = blockIdx.x;
910
+ int s = blockIdx.y;
911
+ int d = blockIdx.z * blockDim.x + threadIdx.x;
912
+
913
+ if (d >= Dim) return;
914
+
915
+ int idx = (b * S + s) * Dim + d;
916
+ int64_t r = R[idx];
917
+
918
+ // 饱和区裁剪(可调阈值 64)
919
+ const int64_t LIM = (int64_t)64 << 23;
920
+ if (r >= LIM) // σ≈1 -> SiLU ~= x
921
+ {
922
+ C[idx] = r;
923
+ return;
924
+ }
925
+ if (r <= -LIM) // σ≈0 -> SiLU ~= 0
926
+ {
927
+ C[idx] = 0;
928
+ return;
929
+ }
930
+
931
+ // y = - x * log2(e) / 2^23 (Q23)
932
+ int64_t y = int64_t((__int128_t(-r) * __int128_t(LOG2E_Q23)) >> 23);
933
+
934
+ // u ≈ 2^y = e^{-x} = 2^k * 2^f (Q23)
935
+ int64_t u = 0;
936
+ int64_t k = y >> 23; // 整数部分
937
+ int64_t f = y & 0x7FFFFF; // 小数部分 (Q23)
938
+ int64_t t = EXP2_FRAC_LUT_Q23[f >> (23 - LOG_TABLE_SIZE)]; // 2^(frac) in Q23
939
+ if (k > -63)
940
+ u = (k < 0) ? (t >> (-k)) : (t << k); // 一般 k<=0
941
+
942
+ // σ = 1 / (1 + u) (Q23)
943
+ int64_t q = exp2_46 / (exp2_23 + u); // Q23
944
+
945
+ // SiLU = x * σ : (Q23 * Q23) >> 23 → Q23
946
+ C[idx] = ((r * q) >> 23);
947
+ }
948
+
949
+ extern "C" void silu_q23(int64_t* R, int64_t* C, int64_t* EXP2_FRAC_LUT_23,
950
+ int Bsz, int S, int Dim)
951
+ {
952
+ dim3 grid(Bsz, S, (Dim + 255) / 256);
953
+ dim3 block(256, 1, 1);
954
+
955
+ silu_kernel_q23<<<grid, block>>>(
956
+ R, C, LOG2E_Q23, EXP2_FRAC_LUT_23,
957
+ Bsz, S, Dim);
958
+ }
959
+
960
+ extern "C" void silu_init_q23(int64_t* EXP2_FRAC_LUT)
961
+ {
962
+ int tableSize = 1 << LOG_TABLE_SIZE;
963
+ for(int i = 0; i < tableSize; i++) {
964
+ // EXP2_FRAC_LUT[i] = uint64_t(std::pow(2, i / 1024.0) * 2^23);
965
+ EXP2_FRAC_LUT[i] = int64_t(std::pow(2, i * 1.0f / tableSize) * exp2_23);
966
+ }
967
+ }
968
+
969
+ extern "C" __global__ void sigmoid_kernel_q23(
970
+ int64_t* R, // [B, S, Dim]
971
+ int64_t* C, // [B, S, Dim]
972
+ int64_t LOG2E_Q23, int64_t* EXP2_FRAC_LUT_Q23,
973
+ int Bsz, int S, int Dim)
974
+ {
975
+ int b = blockIdx.x;
976
+ int s = blockIdx.y;
977
+ int d = blockIdx.z * blockDim.x + threadIdx.x;
978
+
979
+ if (d >= Dim) return;
980
+
981
+ int idx = (b * S + s) * Dim + d;
982
+ int64_t r = R[idx];
983
+
984
+ // 饱和区裁剪(可调阈值 64)
985
+ const int64_t LIM = (int64_t)64 << 23;
986
+ if (r >= LIM) // σ≈1 -> SiLU ~= x
987
+ {
988
+ printf("r: %ld >= LIM", r);
989
+ C[idx] = r;
990
+ return;
991
+ }
992
+ if (r <= -LIM) // σ≈0 -> SiLU ~= 0
993
+ {
994
+ printf("r: %ld <= -LIM", r);
995
+ C[idx] = 0;
996
+ return;
997
+ }
998
+
999
+ // y = - x * log2(e) / 2^23 (Q23)
1000
+ int64_t y = int64_t((__int128_t(-r) * __int128_t(LOG2E_Q23)) >> 23);
1001
+
1002
+ // u ≈ 2^y = e^{-x} = 2^k * 2^f (Q23)
1003
+ int64_t u = 0;
1004
+ int64_t k = y >> 23; // 整数部分
1005
+ int64_t f = y & 0x7FFFFF; // 小数部分 (Q23)
1006
+ int64_t t = EXP2_FRAC_LUT_Q23[f >> (23 - LOG_TABLE_SIZE)]; // 2^(frac) in Q23
1007
+ if (k > -63)
1008
+ u = (k < 0) ? (t >> (-k)) : (t << k); // 一般 k<=0
1009
+
1010
+ // if(s == 0 && d == 4)
1011
+ // {
1012
+ // printf("s: %d, d: %d, x: %ld, y: %ld, k: %ld, f: %ld, t: %ld, u: %ld\n", s, d, r, y, k, f, t, u);
1013
+ // }
1014
+
1015
+ // σ = 1 / (1 + u) (Q23)
1016
+ C[idx] = exp2_46 / (exp2_23 + u);
1017
+ }
1018
+
1019
+ extern "C" void sigmoid_q23(int64_t* R, int64_t* C, int64_t* EXP2_FRAC_LUT_23,
1020
+ int Bsz, int S, int Dim)
1021
+ {
1022
+ dim3 grid(Bsz, S, (Dim + 255) / 256);
1023
+ dim3 block(256, 1, 1);
1024
+
1025
+ sigmoid_kernel_q23<<<grid, block>>>(
1026
+ R, C, LOG2E_Q23, EXP2_FRAC_LUT_23,
1027
+ Bsz, S, Dim);
1028
+ }
1029
+
1030
+ // -- end of silu_q23 -----------------------------
inference/kernel.py ADDED
@@ -0,0 +1,724 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Tuple
2
+
3
+ import math
4
+ import random
5
+ import torch
6
+ import ctypes
7
+ import triton
8
+ import triton.language as tl
9
+ from triton import Config
10
+
11
+
12
+
13
@triton.jit
def act_quant_kernel(x_ptr, y_ptr, s_ptr, BLOCK_SIZE: tl.constexpr):
    """Per-block absmax quantization kernel.

    Each program instance handles one contiguous block of ``BLOCK_SIZE``
    elements: it computes the block scale as max(|x|) / 448 (448 being the
    largest finite float8_e4m3fn value), writes x / scale into ``y_ptr``
    (cast to the output element type) and the scale itself into ``s_ptr``.

    Args:
        x_ptr: Pointer to the input tensor.
        y_ptr: Pointer to the quantized output tensor.
        s_ptr: Pointer to the per-block scaling factors.
        BLOCK_SIZE: Number of elements processed per program instance.
    """
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    block = tl.load(x_ptr + offsets).to(tl.float32)
    # NOTE(review): a block of all zeros yields scale 0 and a 0/0 divide;
    # presumably inputs are never all-zero per block — confirm.
    scale = tl.max(tl.abs(block)) / 448.
    quantized = (block / scale).to(y_ptr.dtype.element_ty)
    tl.store(y_ptr + offsets, quantized)
    tl.store(s_ptr + pid, scale)
35
+
36
+ # 把 张量 x 进行 量化
37
def act_quant(x: torch.Tensor, block_size: int = 128) -> Tuple[torch.Tensor, torch.Tensor]:
    """Quantize `x` block-wise to FP8 (float8_e4m3fn).

    The last dimension is split into groups of ``block_size`` elements and
    each group gets its own float32 scale, so `s` has shape
    ``(*x.shape[:-1], x.shape[-1] // block_size)``.

    Args:
        x: Contiguous input tensor whose last dimension is divisible by
            ``block_size``.
        block_size: Elements per quantization group (default 128).

    Returns:
        Tuple of (quantized tensor in float8_e4m3fn, float32 scales).
    """
    assert x.is_contiguous(), 'Input tensor must be contiguous'
    assert x.size(-1) % block_size == 0, f'Last dimension size must be divisible by block_size (block_size={block_size})'

    quantized = torch.empty_like(x, dtype=torch.float8_e4m3fn)
    scales = x.new_empty(*x.size()[:-1], x.size(-1) // block_size, dtype=torch.float32)

    # One Triton program per block of BLOCK_SIZE elements over the flat tensor.
    def grid(meta):
        return (triton.cdiv(x.numel(), meta['BLOCK_SIZE']),)

    act_quant_kernel[grid](x, quantized, scales, BLOCK_SIZE=block_size)
    return quantized, scales
65
+
66
+
67
@triton.jit
def weight_dequant_kernel(x_ptr, s_ptr, y_ptr, M, N, BLOCK_SIZE: tl.constexpr):
    """Tile-wise weight dequantization kernel.

    Each program instance handles one BLOCK_SIZE x BLOCK_SIZE tile of the
    (M, N) weight matrix: it loads the quantized tile, multiplies by that
    tile's scalar scale factor, and stores the float32 result.

    Args:
        x_ptr: Pointer to the quantized weights, shape (M, N).
        s_ptr: Pointer to the per-tile scale factors.
        y_ptr: Pointer to the dequantized output, shape (M, N).
        M: Number of rows.
        N: Number of columns.
        BLOCK_SIZE: Tile edge length.
    """
    pid_m = tl.program_id(axis=0)
    pid_n = tl.program_id(axis=1)
    tiles_per_row = tl.cdiv(N, BLOCK_SIZE)  # scale grid width
    row_offs = pid_m * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    col_offs = pid_n * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    offs = row_offs[:, None] * N + col_offs[None, :]
    in_bounds = (row_offs[:, None] < M) & (col_offs[None, :] < N)
    tile = tl.load(x_ptr + offs, mask=in_bounds).to(tl.float32)
    scale = tl.load(s_ptr + pid_m * tiles_per_row + pid_n)
    tl.store(y_ptr + offs, tile * scale, mask=in_bounds)
94
+
95
+
96
def weight_dequant(x: torch.Tensor, s: torch.Tensor, block_size: int = 128) -> torch.Tensor:
    """Dequantize a 2-D weight tensor using per-tile scale factors.

    Args:
        x: Quantized weight tensor of shape (M, N), contiguous.
        s: Per-tile scale tensor of shape
            (ceil(M / block_size), ceil(N / block_size)), contiguous.
        block_size: Tile edge used during quantization (default 128).

    Returns:
        Dequantized tensor with the same shape as `x`, in the default dtype.

    Raises:
        AssertionError: If `x` or `s` is not contiguous or not 2-D.
    """
    assert x.is_contiguous() and s.is_contiguous(), 'Input tensors must be contiguous'
    assert x.dim() == 2 and s.dim() == 2, 'Input tensors must have 2 dimensions'

    M, N = x.size()
    out = torch.empty_like(x, dtype=torch.get_default_dtype())

    # One program per BLOCK_SIZE x BLOCK_SIZE tile.
    def grid(meta):
        return (triton.cdiv(M, meta['BLOCK_SIZE']), triton.cdiv(N, meta['BLOCK_SIZE']))

    weight_dequant_kernel[grid](x, s, out, M, N, BLOCK_SIZE=block_size)
    return out
118
+
119
+
120
# Autotuning search space for fp8_gemm_kernel: the K tile is fixed at 128
# (the quantization block size), while M/N tiles and pipeline depth vary.
fp8_gemm_configs = [
    Config({'BLOCK_SIZE_M': block_m, 'BLOCK_SIZE_N': block_n, 'BLOCK_SIZE_K': 128}, num_stages=num_stages, num_warps=8)
    for block_m in [16, 32, 64] for block_n in [32, 64, 128] for num_stages in [3, 4, 5, 6]
]
124
+
125
@triton.autotune(configs=fp8_gemm_configs, key=['N', 'K'])
@triton.jit
def fp8_gemm_kernel(a_ptr, b_ptr, c_ptr,
                    a_s_ptr, b_s_ptr,
                    M, N: tl.constexpr, K: tl.constexpr,
                    BLOCK_SIZE_M: tl.constexpr,
                    BLOCK_SIZE_N: tl.constexpr,
                    BLOCK_SIZE_K: tl.constexpr):
    """
    Performs a matrix multiplication operation on FP8 matrices with scaling factors.

    C = A @ B^T where A is (M, K) and B is (N, K); each K-slice's partial
    product is rescaled by the per-block scale factors of A and B before
    accumulation in float32.

    Args:
        a_ptr (tl.tensor): Pointer to the first input matrix A, row-major (M, K).
        b_ptr (tl.tensor): Pointer to the second input matrix B, row-major (N, K).
        c_ptr (tl.tensor): Pointer to the output matrix C, row-major (M, N).
        a_s_ptr (tl.tensor): Pointer to the scaling factors for matrix A
            (one per row per K-block).
        b_s_ptr (tl.tensor): Pointer to the scaling factors for matrix B
            (one per BLOCK_SIZE_K x BLOCK_SIZE_K tile).
        M (int): Number of rows in matrix A and C.
        N (tl.constexpr): Number of columns in matrix B and C.
        K (tl.constexpr): Number of columns in matrix A and rows in matrix B.
        BLOCK_SIZE_M (tl.constexpr): Block size for the M dimension.
        BLOCK_SIZE_N (tl.constexpr): Block size for the N dimension.
        BLOCK_SIZE_K (tl.constexpr): Block size for the K dimension.

    Returns:
        None
    """
    pid_m = tl.program_id(axis=0)
    pid_n = tl.program_id(axis=1)
    k = tl.cdiv(K, BLOCK_SIZE_K)  # number of K tiles (= scale entries per row)
    # Modulo keeps out-of-range rows/cols in bounds for the loads; the final
    # store is masked, so any duplicated loads are discarded.
    offs_m = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
    offs_n = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    # A is read as (BLOCK_M, BLOCK_K); B as (BLOCK_K, BLOCK_N), i.e. B^T tile.
    a_ptrs = a_ptr + offs_m[:, None] * K + offs_k[None, :]
    b_ptrs = b_ptr + offs_n[None, :] * K + offs_k[:, None]
    a_s_ptrs = a_s_ptr + offs_m * k
    b_s_ptrs = b_s_ptr + (offs_n // BLOCK_SIZE_K) * k

    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for i in range(k):
        # Mask the tail K tile when K is not a multiple of BLOCK_SIZE_K.
        a = tl.load(a_ptrs, mask=offs_k[None, :] < K - i * BLOCK_SIZE_K, other=0.0)
        b = tl.load(b_ptrs, mask=offs_k[:, None] < K - i * BLOCK_SIZE_K, other=0.0)
        a_s = tl.load(a_s_ptrs)
        b_s = tl.load(b_s_ptrs)
        # Dequantize each partial product with the current K-block's scales.
        accumulator += tl.dot(a, b) * a_s[:, None] * b_s[None, :]
        a_ptrs += BLOCK_SIZE_K
        b_ptrs += BLOCK_SIZE_K
        a_s_ptrs += 1
        b_s_ptrs += 1
    c = accumulator.to(c_ptr.dtype.element_ty)
    # Recompute unwrapped offsets for the masked store.
    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    c_ptrs = c_ptr + offs_m[:, None] * N + offs_n[None, :]
    mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
    tl.store(c_ptrs, c, mask=mask)
180
+
181
+ # FP8通用矩阵乘法
182
def fp8_gemm(a: torch.Tensor, a_s: torch.Tensor, b: torch.Tensor, b_s: torch.Tensor):
    """General matrix multiply in FP8 precision: returns a @ b^T.

    Leading dimensions of `a` are flattened into M; `b` is (N, K) and the
    result has shape (*a.shape[:-1], N) in the default dtype.

    Args:
        a: First input matrix, contiguous, last dim K.
        a_s: Per-block scale factors for `a`, contiguous.
        b: Second input matrix (N, K), contiguous.
        b_s: Per-block scale factors for `b`, contiguous.

    Returns:
        torch.Tensor: The result of the matrix multiplication.
    """
    assert a.is_contiguous() and b.is_contiguous(), 'Input tensors must be contiguous'
    assert a_s.is_contiguous() and b_s.is_contiguous(), 'Scaling factor tensors must be contiguous'

    K = a.size(-1)
    M = a.numel() // K          # all leading dims folded into rows
    N = b.size(0)
    out = a.new_empty(*a.size()[:-1], N, dtype=torch.get_default_dtype())

    def grid(META):
        return (triton.cdiv(M, META['BLOCK_SIZE_M']), triton.cdiv(N, META['BLOCK_SIZE_N']))

    fp8_gemm_kernel[grid](a, b, out, a_s, b_s, M, N, K)
    return out
204
+
205
# Load the CUDA shared library with the fixed-point kernels.
# NOTE(review): path is relative to the working directory — the process must
# be started from the directory containing libint64gemm.so.
lib = ctypes.CDLL("./libint64gemm.so")

# Declare argument types for every exported launcher so ctypes marshals
# pointers and scalars correctly (pointers are passed as tensor.data_ptr()).
lib.int64_64_bmm_broadcast_launcher.argtypes = [
    ctypes.c_void_p,  # A
    ctypes.c_void_p,  # B
    ctypes.c_void_p,  # C
    ctypes.c_void_p,  # R
    ctypes.c_longlong, ctypes.c_longlong, ctypes.c_longlong,
    ctypes.c_int, ctypes.c_int, ctypes.c_int, ctypes.c_int
]

lib.int64_32_bmm_broadcast_launcher.argtypes = [
    ctypes.c_void_p,  # A
    ctypes.c_void_p,  # B
    ctypes.c_void_p,  # C
    ctypes.c_void_p,  # R
    ctypes.c_longlong, ctypes.c_longlong, ctypes.c_longlong,
    ctypes.c_int, ctypes.c_int, ctypes.c_int, ctypes.c_int
]

lib.complex_int64_mul.argtypes = [
    ctypes.c_void_p,  # A
    ctypes.c_void_p,  # B
    ctypes.c_void_p,  # C
    ctypes.c_int, ctypes.c_int, ctypes.c_int, ctypes.c_int
]

lib.rms_norm_32.argtypes = [
    ctypes.c_void_p,  # A
    ctypes.c_void_p,  # W
    ctypes.c_void_p,  # rms
    ctypes.c_void_p,  # C
    ctypes.c_int, ctypes.c_int
]

lib.rms_norm_64.argtypes = [
    ctypes.c_void_p,  # A
    ctypes.c_void_p,  # W
    ctypes.c_void_p,  # rms
    ctypes.c_void_p,  # C
    ctypes.c_int, ctypes.c_int
]

lib.einsum_bshd_hdc_bshc.argtypes = [
    ctypes.c_void_p,  # A
    ctypes.c_void_p,  # B
    ctypes.c_void_p,  # C
    ctypes.c_longlong,
    ctypes.c_int, ctypes.c_int, ctypes.c_int, ctypes.c_int, ctypes.c_int
]

lib.einsum_bshc_btc_bsht.argtypes = [
    ctypes.c_void_p,  # A
    ctypes.c_void_p,  # B
    ctypes.c_void_p,  # C
    ctypes.c_longlong,
    ctypes.c_int, ctypes.c_int, ctypes.c_int, ctypes.c_int, ctypes.c_int
]

lib.einsum_bsht_btc_bshc.argtypes = [
    ctypes.c_void_p,  # A
    ctypes.c_void_p,  # B
    ctypes.c_void_p,  # C
    ctypes.c_longlong,
    ctypes.c_int, ctypes.c_int, ctypes.c_int, ctypes.c_int, ctypes.c_int
]

lib.einsum_bshc_hdc_bshd.argtypes = [
    ctypes.c_void_p,  # A
    ctypes.c_void_p,  # B
    ctypes.c_void_p,  # C
    ctypes.c_longlong,
    ctypes.c_int, ctypes.c_int, ctypes.c_int, ctypes.c_int, ctypes.c_int
]

lib.softmax_q21.argtypes = [
    ctypes.c_void_p,
    ctypes.c_void_p,
    ctypes.c_void_p,
    ctypes.c_int, ctypes.c_int, ctypes.c_int, ctypes.c_int
]

lib.softmax_q19.argtypes = [
    ctypes.c_void_p,
    ctypes.c_void_p,
    ctypes.c_void_p,
    ctypes.c_int, ctypes.c_int, ctypes.c_int, ctypes.c_int
]

lib.softmax_init_q21.argtypes = [
    ctypes.c_void_p
]

lib.softmax_init_q19.argtypes = [
    ctypes.c_void_p
]

lib.silu_q25.argtypes = [
    ctypes.c_void_p,
    ctypes.c_void_p,
    ctypes.c_void_p,
    ctypes.c_int, ctypes.c_int, ctypes.c_int
]

lib.sigmoid_q25.argtypes = [
    ctypes.c_void_p,
    ctypes.c_void_p,
    ctypes.c_void_p,
    ctypes.c_int, ctypes.c_int, ctypes.c_int
]

lib.silu_init_q25.argtypes = [
    ctypes.c_void_p
]

lib.silu_q23.argtypes = [
    ctypes.c_void_p,
    ctypes.c_void_p,
    ctypes.c_void_p,
    ctypes.c_int, ctypes.c_int, ctypes.c_int
]

lib.sigmoid_q23.argtypes = [
    ctypes.c_void_p,
    ctypes.c_void_p,
    ctypes.c_void_p,
    ctypes.c_int, ctypes.c_int, ctypes.c_int
]

lib.silu_init_q23.argtypes = [
    ctypes.c_void_p
]
339
+
340
+
341
def int64_bmm_broadcast(A: torch.Tensor, B: torch.Tensor, a_rescale, b_rescale, c_rescale) -> tuple[torch.Tensor, torch.Tensor]:
    """Batched int64 matmul with broadcast: (B, M, K) x (N, K) -> (B, M, N).

    Dispatches to the CUDA launcher matching B's dtype (int64 or int32).

    Args:
        A: int64 CUDA tensor of shape (B, M, K).
        B: int64 or int32 CUDA tensor of shape (N, K).
        a_rescale, b_rescale, c_rescale: fixed-point rescale factors
            forwarded to the kernel.

    Returns:
        (C, R): quotient and remainder tensors of shape (B, M, N), int64.

    Raises:
        TypeError: if B is neither int64 nor int32.
    """
    assert A.dtype == torch.int64
    assert A.is_cuda and B.is_cuda
    Bdim, M, K = A.shape
    N, K2 = B.shape
    assert K2 == K, f"Inner dimensions must match: {K2} != {K}"

    C = torch.empty((Bdim, M, N), dtype=torch.int64, device="cuda")
    R = torch.empty((Bdim, M, N), dtype=torch.int64, device="cuda")

    if B.dtype == torch.int64:
        lib.int64_64_bmm_broadcast_launcher(
            A.data_ptr(), B.data_ptr(), C.data_ptr(), R.data_ptr(),
            a_rescale, b_rescale, c_rescale,
            Bdim, M, K, N
        )
    elif B.dtype == torch.int32:
        lib.int64_32_bmm_broadcast_launcher(
            A.data_ptr(), B.data_ptr(), C.data_ptr(), R.data_ptr(),
            a_rescale, b_rescale, c_rescale,
            Bdim, M, K, N
        )
    else:
        # Previously this only printed a message and returned the
        # *uninitialized* C/R buffers; fail loudly instead.
        raise TypeError(f'Unsupported B type: {B.dtype}')
    return (C, R)
372
+
373
def complex_int64_mul_broadcast(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
    """Element-wise complex multiply of int64 tensors via the CUDA kernel.

    Both inputs are (batch, seq, head, head_dim) int64 CUDA tensors; the
    result has the same shape and dtype.
    """
    # print(f'A type: {A.dtype}, B type: {B.dtype}')
    assert A.dtype == torch.int64 and B.dtype == torch.int64
    assert A.is_cuda and B.is_cuda

    batch, seq_len, n_head, head_dim = A.shape

    out = torch.zeros(A.shape, dtype=torch.int64, device=A.device)

    lib.complex_int64_mul(
        A.data_ptr(), B.data_ptr(), out.data_ptr(),
        batch, seq_len, n_head, head_dim)

    return out
396
+
397
def einsum_bshd_hdc_bshc(A: torch.Tensor, B: torch.Tensor, rescale) -> torch.Tensor:
    """Int64 contraction 'bshd,hdc->bshc' via the CUDA kernel."""
    assert A.shape[2] == B.shape[0] and A.shape[3] == B.shape[1]
    assert A.is_cuda and B.is_cuda

    b, s, h, d = A.shape
    c = B.shape[2]

    out = torch.zeros([b, s, h, c], dtype=torch.int64, device=A.device)
    lib.einsum_bshd_hdc_bshc(A.data_ptr(), B.data_ptr(), out.data_ptr(),
                             rescale, b, s, h, d, c)
    return out
416
+
417
def einsum_bshc_btc_bsht(A: torch.Tensor, B: torch.Tensor, rescale) -> torch.Tensor:
    """Int64 contraction 'bshc,btc->bsht' via the CUDA kernel."""
    b, s, h, c = A.shape
    t = B.shape[1]

    assert b == B.shape[0] and c == B.shape[2]
    assert A.is_cuda and B.is_cuda

    out = torch.zeros([b, s, h, t], dtype=torch.int64, device=A.device)
    lib.einsum_bshc_btc_bsht(A.data_ptr(), B.data_ptr(), out.data_ptr(),
                             rescale, b, s, h, t, c)
    return out
436
+
437
def einsum_bsht_btc_bshc(A: torch.Tensor, B: torch.Tensor, rescale) -> torch.Tensor:
    """Int64 contraction 'bsht,btc->bshc' via the CUDA kernel."""
    b, s, h, t = A.shape
    c = B.shape[2]

    assert b == B.shape[0] and t == B.shape[1]
    assert A.is_cuda and B.is_cuda

    out = torch.zeros([b, s, h, c], dtype=torch.int64, device=A.device)
    lib.einsum_bsht_btc_bshc(A.data_ptr(), B.data_ptr(), out.data_ptr(),
                             rescale, b, s, h, t, c)
    return out
456
+
457
def einsum_bshc_hdc_bshd(A: torch.Tensor, B: torch.Tensor, rescale) -> torch.Tensor:
    """Int64 contraction 'bshc,hdc->bshd' via the CUDA kernel."""
    b, s, h, c = A.shape
    d = B.shape[1]

    assert h == B.shape[0] and c == B.shape[2]
    assert A.is_cuda and B.is_cuda

    out = torch.zeros([b, s, h, d], dtype=torch.int64, device=A.device)
    lib.einsum_bshc_hdc_bshd(A.data_ptr(), B.data_ptr(), out.data_ptr(),
                             rescale, b, s, h, d, c)
    return out
476
+
477
def int64_RMS0(A: torch.Tensor, eps: int, dim: int) -> int:
    """Exact integer RMS of a 1-D int64 tensor.

    Computes isqrt((eps + sum(a^2)) // dim) using Python integers, which
    are arbitrary precision, so the sum of squares cannot overflow.

    Args:
        A: 1-D int64 tensor (CPU or CUDA).
        eps: integer epsilon added to the sum of squares.
        dim: divisor applied before the integer square root.

    Returns:
        The integer RMS value (Python int, not a tensor — the original
        annotation said torch.Tensor, which was wrong).
    """
    assert A.dtype == torch.int64
    assert A.ndim == 1

    # One bulk host transfer instead of one .item() sync per element.
    vals = A.tolist()
    acc = eps + sum(v * v for v in vals)

    acc //= dim

    return math.isqrt(acc)
495
+
496
# Fixed-point RMS-norm kernel.
#   x      is Q31 (scale 2**31), range 0 .. 2^31
#   weight is Q21 (scale 2**21), range 2^5 .. 2^20
#   rms    is Q31 (scale 2**31)
#   result is Q21: 31 + 21 - 31 = 21
@triton.jit
def int64_rms_norm_kernel(
    A_ptr, W_ptr, C_ptr, RMS_ptr,
    N,
    batch_stride_a, batch_stride_c,
    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr
):
    # One program instance per row; the row is walked serially element by
    # element (scalar loads/stores, no vectorization).
    pid_m = tl.program_id(0)

    for i in range(0, N):
        a_ptrs = A_ptr + pid_m * batch_stride_a + i
        w_ptrs = W_ptr + i
        rms_ptrs = RMS_ptr + pid_m
        a = tl.load(a_ptrs, mask=None)
        w = tl.load(w_ptrs, mask=None)
        rms = tl.load(rms_ptrs, mask=None)

        # Q31 * Q21 // Q31 -> Q21.
        res = a * w // rms

        # Guard against int64 overflow of the intermediate product.
        # NOTE(review): the check runs after `res` is already computed and
        # only trips in device-assert-enabled builds — confirm intent.
        prod = a * w
        tl.device_assert(prod > -2 ** 62 and prod < 2 ** 62, "Integer overflow risk!!!")

        c_ptrs = C_ptr + pid_m * batch_stride_c + i
        tl.store(c_ptrs, res, mask=None)
524
+
525
# Reusable scratch buffers for per-row RMS values (avoids reallocating on
# every call). Grown on demand by RMS_Norm_int64; previously fixed at 500
# rows, which crashed for larger batches.
rms = torch.empty((500, ), dtype=torch.int64, device='cpu')
rms_gpu = torch.empty((500, ), dtype=torch.int64, device='cuda')

def RMS_Norm_int64(A: torch.Tensor, W: torch.Tensor, eps, dim) -> tuple:
    """Fixed-point RMS-norm over the rows of a 2-D int64 CUDA tensor.

    The per-row RMS is computed exactly on the CPU (arbitrary-precision
    Python ints via int64_RMS0), then the elementwise normalization runs on
    the GPU through rms_norm_32 / rms_norm_64, chosen by W's dtype.

    Returns:
        (C, rms): the normalized (M, N) int64 tensor and the CPU scratch
        tensor of per-row RMS values (only the first M entries are valid).
    """
    global rms
    global rms_gpu

    assert A.dtype == torch.int64
    assert A.is_cuda and W.is_cuda
    assert A.ndim == 2

    M, N = A.shape

    # Grow both scratch buffers together if the batch exceeds their size.
    if M > rms.numel():
        rms = torch.empty((M, ), dtype=torch.int64, device='cpu')
        rms_gpu = torch.empty((M, ), dtype=torch.int64, device='cuda')

    for i in range(M):
        rms[i] = int64_RMS0(A[i], eps, dim)

    rms_gpu.copy_(rms)
    C = torch.empty((M, N), dtype=torch.int64, device=A.device)

    if W.dtype == torch.int32:
        lib.rms_norm_32(A.data_ptr(), W.data_ptr(), rms_gpu.data_ptr(), C.data_ptr(), M, N)
    else:
        lib.rms_norm_64(A.data_ptr(), W.data_ptr(), rms_gpu.data_ptr(), C.data_ptr(), M, N)

    return (C, rms)
551
+
552
+
553
def saveTensor(fileName, t):
    """Serialize tensor `t` to `fileName` as raw C-order bytes.

    The tensor is detached, moved to CPU if necessary, and made contiguous
    before its buffer is dumped verbatim (no header, no dtype metadata).
    """
    t = t.detach()
    if t.device.type != "cpu":
        t = t.cpu()
    t = t.contiguous()
    # Single binary open. The previous implementation opened the file twice
    # (text mode, then binary inside it); the first open only truncated the
    # file and its handle was never used.
    with open(fileName, "wb") as f:
        # .numpy() -> bytes (C-order)
        f.write(t.numpy().tobytes(order="C"))
566
+
567
# GPU lookup tables for the fractional part of exp2, used by the
# fixed-point softmax kernels. Populated by the init functions below.
EXP2_FRAC_LUT_Q21 = None
# LOG_TABLE_SIZE = 10
LOG_TABLE_SIZE = 8

def softmax_init_q21():
    """Fill the Q21 exp2-fraction LUT on the CPU and move it to the GPU."""
    global EXP2_FRAC_LUT_Q21

    table = torch.zeros((2 ** LOG_TABLE_SIZE, ), dtype=torch.int64, device="cpu")
    lib.softmax_init_q21(table.data_ptr())

    EXP2_FRAC_LUT_Q21 = table.cuda()

EXP2_FRAC_LUT_Q19 = None
def softmax_init_q19():
    """Fill the Q19 exp2-fraction LUT on the CPU and move it to the GPU."""
    global EXP2_FRAC_LUT_Q19

    table = torch.zeros((2 ** LOG_TABLE_SIZE, ), dtype=torch.int64, device="cpu")
    lib.softmax_init_q19(table.data_ptr())

    EXP2_FRAC_LUT_Q19 = table.cuda()
    # saveTensor(f'zkdata/softmax_q19_table.bin', table.cpu())
591
+ # saveTensor(f'zkdata/softmax_q19_table.bin', EXP2_FRAC_LUT0.cpu())
592
+
593
+
594
+
595
def softmax_q21(R: torch.Tensor, C: torch.Tensor):
    """Fixed-point (Q21) softmax over the last dim of R, written into C.

    Requires softmax_init_q21() to have populated the LUT first.
    """
    assert R.is_cuda and C.is_cuda

    batch, seq, heads, toks = R.shape
    lib.softmax_q21(R.data_ptr(), C.data_ptr(), EXP2_FRAC_LUT_Q21.data_ptr(),
                    batch, seq, heads, toks)

def softmax_q19(R: torch.Tensor, C: torch.Tensor):
    """Fixed-point (Q19) softmax over the last dim of R, written into C.

    Requires softmax_init_q19() to have populated the LUT first.
    """
    assert R.is_cuda and C.is_cuda

    batch, seq, heads, toks = R.shape
    lib.softmax_q19(R.data_ptr(), C.data_ptr(), EXP2_FRAC_LUT_Q19.data_ptr(),
                    batch, seq, heads, toks)
620
+
621
+
622
# start of silu_q25 ---------------------------------
EXP2_FRAC_LUT_Q25 = None

def silu_init_q25():
    """Fill the Q25 exp2-fraction LUT on the CPU and move it to the GPU."""
    global EXP2_FRAC_LUT_Q25

    table = torch.zeros((2 ** LOG_TABLE_SIZE, ), dtype=torch.int64, device="cpu")
    lib.silu_init_q25(table.data_ptr())

    EXP2_FRAC_LUT_Q25 = table.cuda()

def silu_q25(R: torch.Tensor, C: torch.Tensor):
    """Fixed-point (Q25) SiLU applied elementwise to R, written into C."""
    batch, seq, feat = R.shape
    lib.silu_q25(R.data_ptr(), C.data_ptr(), EXP2_FRAC_LUT_Q25.data_ptr(),
                 batch, seq, feat)

def sigmoid_q25(R: torch.Tensor, C: torch.Tensor):
    """Fixed-point (Q25) sigmoid applied elementwise to R, written into C."""
    batch, seq, feat = R.shape
    lib.sigmoid_q25(R.data_ptr(), C.data_ptr(), EXP2_FRAC_LUT_Q25.data_ptr(),
                    batch, seq, feat)
# end of silu_q25 ---------------------------------
654
+
655
# start of silu_q23 ---------------------------------
EXP2_FRAC_LUT_Q23 = None

def silu_init_q23():
    """Fill the Q23 exp2-fraction LUT on the CPU and move it to the GPU."""
    global EXP2_FRAC_LUT_Q23

    table = torch.zeros((2 ** LOG_TABLE_SIZE, ), dtype=torch.int64, device="cpu")
    lib.silu_init_q23(table.data_ptr())

    EXP2_FRAC_LUT_Q23 = table.cuda()
    # saveTensor(f'zkdata/silu_q23_table.bin', table.cpu())

def silu_q23(R: torch.Tensor, C: torch.Tensor):
    """Fixed-point (Q23) SiLU applied elementwise to R, written into C."""
    batch, seq, feat = R.shape
    lib.silu_q23(R.data_ptr(), C.data_ptr(), EXP2_FRAC_LUT_Q23.data_ptr(),
                 batch, seq, feat)

def sigmoid_q23(R: torch.Tensor, C: torch.Tensor):
    """Fixed-point (Q23) sigmoid applied elementwise to R, written into C."""
    batch, seq, feat = R.shape
    lib.sigmoid_q23(R.data_ptr(), C.data_ptr(), EXP2_FRAC_LUT_Q23.data_ptr(),
                    batch, seq, feat)
# end of silu_q23 ---------------------------------
689
+
690
+
691
if __name__ == "__main__":
    # Smoke test: compare the fixed-point Q21 softmax kernel against
    # torch's float softmax on a small random tensor (requires CUDA and
    # libint64gemm.so).
    softmax_init_q21()

    torch.manual_seed(0)
    device = "cuda"

    Bsz = 1
    S = 1
    H = 2
    T = 10

    A = torch.rand([Bsz, S, H, T], dtype=torch.bfloat16, device=device)
    # Quantize to Q21 fixed point.
    a = (A.to(torch.float32) * (2 ** 21)).to(torch.int64)
    # a = (A * (2 ** 21)).to(torch.int64)

    print('A: ' + str(A))
    print('a: ' + str(a))

    c = torch.zeros([Bsz, S, H, T], dtype=torch.int64, device=device)

    softmax_q21(a, c)

    # Float reference.
    r0 = A.softmax(dim=-1, dtype=torch.float32).type_as(A)
    print('r0: ' + str(r0))

    # Dequantize the fixed-point result for visual comparison.
    r1 = (c.to(torch.float32) * (2 ** -21)).to(torch.bfloat16)
    print('r1: ' + str(r1))


    # Re-quantized reference, to compare integer-vs-integer.
    R0 = (r0.to(torch.float32) * (2 ** 21)).to(torch.int64)
    # R0 = (r0 * (2 ** 21)).to(torch.int64)
    print('R0: ' + str(R0))

    print('R1: ' + str(c))
inference/model.py ADDED
@@ -0,0 +1,1631 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import math
3
+ import datetime
4
+ from dataclasses import dataclass
5
+ from typing import Tuple, Optional, Literal
6
+
7
+ import torch
8
+ from torch import nn
9
+ import torch.nn.functional as F
10
+ import torch.distributed as dist
11
+ from safetensors.torch import load_model
12
+
13
+ from kernel import act_quant, weight_dequant, fp8_gemm, int64_bmm_broadcast, \
14
+ complex_int64_mul_broadcast, einsum_bshd_hdc_bshc, einsum_bshc_btc_bsht, softmax_init_q21, softmax_q21, einsum_bsht_btc_bshc, einsum_bshc_hdc_bshd, \
15
+ silu_init_q25, silu_q25, sigmoid_q25, softmax_init_q19, softmax_q19, silu_init_q23, silu_q23, sigmoid_q23, RMS_Norm_int64
16
+
17
+
18
# Distributed / quantization configuration shared across this module.
world_size = 1   # number of model-parallel processes
rank = 0         # this process's rank
block_size = 128  # quantization block size for fp8 weights
gemm_impl: Literal["bf16", "fp8"] = "bf16"       # GEMM backend for quantized weights
attn_impl: Literal["naive", "absorb"] = "absorb"  # attention implementation

# NOTE(review): presumably toggles SNARK-related instrumentation elsewhere
# in the file — confirm; it is not read in this chunk.
snark = False
25
+
26
@dataclass
class ModelArgs:
    """
    Data class for defining model arguments and hyperparameters.

    Defaults match the 16B configuration (configs/config_16B.json).

    Attributes:
        max_batch_size (int): Maximum batch size.
        max_seq_len (int): Maximum sequence length.
        dtype (Literal["bf16", "fp8"]): Data type for computations.
        vocab_size (int): Vocabulary size.
        dim (int): Model dimension.
        inter_dim (int): Intermediate dimension for MLP layers.
        moe_inter_dim (int): Intermediate dimension for MoE layers.
        n_layers (int): Number of transformer layers.
        n_dense_layers (int): Number of dense layers in the model.
        n_heads (int): Number of attention heads.
        n_routed_experts (int): Number of routed experts for MoE layers.
        n_shared_experts (int): Number of shared experts for MoE layers.
        n_activated_experts (int): Number of activated experts in MoE layers.
        n_expert_groups (int): Number of expert groups.
        n_limited_groups (int): Number of limited groups for MoE routing.
        score_func (Literal["softmax", "sigmoid"]): Scoring function for MoE routing.
        route_scale (float): Scaling factor for routing scores.
        q_lora_rank (int): LoRA rank for query projections.
        kv_lora_rank (int): LoRA rank for key-value projections.
        qk_nope_head_dim (int): Dimension for query-key projections without positional embeddings.
        qk_rope_head_dim (int): Dimension for query-key projections with rotary embeddings.
        v_head_dim (int): Dimension for value projections.
        original_seq_len (int): Original sequence length.
        rope_theta (float): Base for rotary positional encoding.
        rope_factor (float): Scaling factor for extended sequence lengths.
        beta_fast (int): Fast beta correction factor.
        beta_slow (int): Slow beta correction factor.
        mscale (float): Scaling factor for extended attention.
    """
    max_batch_size: int = 8
    max_seq_len: int = 4096 * 4
    dtype: Literal["bf16", "fp8"] = "bf16"
    vocab_size: int = 102400
    dim: int = 2048
    inter_dim: int = 10944
    moe_inter_dim: int = 1408
    n_layers: int = 27
    n_dense_layers: int = 1
    n_heads: int = 16
    # moe
    n_routed_experts: int = 64
    n_shared_experts: int = 2
    n_activated_experts: int = 6
    n_expert_groups: int = 1
    n_limited_groups: int = 1
    score_func: Literal["softmax", "sigmoid"] = "softmax"
    route_scale: float = 1.
    # mla (multi-head latent attention)
    q_lora_rank: int = 0
    kv_lora_rank: int = 512
    qk_nope_head_dim: int = 128
    qk_rope_head_dim: int = 64
    v_head_dim: int = 128
    # yarn (RoPE extension)
    original_seq_len: int = 4096
    rope_theta: float = 10000.0
    rope_factor: float = 40
    beta_fast: int = 32
    beta_slow: int = 1
    mscale: float = 1.
92
+
93
def saveTensor(fileName, t):
    """Serialize tensor `t` to `fileName` as raw C-order bytes.

    The tensor is detached, moved to CPU if necessary, and made contiguous
    before its buffer is dumped verbatim (no header, no dtype metadata).
    """
    t = t.detach()
    if t.device.type != "cpu":
        t = t.cpu()
    t = t.contiguous()
    # Single binary open. The previous implementation opened the file twice
    # (text mode, then binary inside it); the first open only truncated the
    # file and its handle was never used.
    with open(fileName, "wb") as f:
        # .numpy() -> bytes (C-order)
        f.write(t.numpy().tobytes(order="C"))
102
+
103
class ParallelEmbedding(nn.Module):
    """
    Embedding layer with parallelism support across distributed processes.

    The vocabulary is sharded evenly across ranks; each rank holds a
    contiguous slice [vocab_start_idx, vocab_end_idx) of the table.

    Args:
        vocab_size (int): Vocabulary size.
        dim (int): Embedding dimension.
    """
    def __init__(self, vocab_size: int, dim: int):
        super().__init__()
        self.vocab_size = vocab_size
        self.dim = dim
        assert vocab_size % world_size == 0, f"Vocabulary size must be divisible by world size (world_size={world_size})"
        self.part_vocab_size = (vocab_size // world_size)
        self.vocab_start_idx = rank * self.part_vocab_size
        self.vocab_end_idx = self.vocab_start_idx + self.part_vocab_size
        # weight shape: [part_vocab_size, dim]; stored as int64 — the
        # embedding table holds fixed-point values, and is a buffer (loaded
        # from a checkpoint, never trained).
        self.register_buffer("weight", torch.empty(self.part_vocab_size, self.dim, dtype=torch.int64))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass for parallel embedding layer.

        Args:
            x (torch.Tensor): Input tensor containing token indices.

        Returns:
            torch.Tensor: Embedded representations (int64 fixed point).
        """
        # print('aaab ' + str(self.weight[0][0].type()))
        if world_size > 1:
            # Indices outside this rank's vocab shard.
            mask = (x < self.vocab_start_idx) | (x >= self.vocab_end_idx)
            # Shift indices into the local shard's range.
            x = x - self.vocab_start_idx
            # Clamp out-of-shard indices to 0; their rows are zeroed below
            # before the all-reduce combines shards.
            x[mask] = 0
        y = F.embedding(x, self.weight)
        if world_size > 1:
            y[mask] = 0
            dist.all_reduce(y)

        # print(f'ParallelEmbedding x: {x}', flush=True)
        return y
150
+
151
+
152
def linear(x: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    """Apply y = x @ weight.T + bias, dispatching on weight quantization.

    Dispatch:
        - Unquantized weights (element_size > 1, e.g. bf16/fp32/int64):
          plain F.linear.
        - fp8 weights (element_size == 1) with gemm_impl == "bf16":
          dequantize via weight_dequant, then F.linear.
        - fp8 weights otherwise: block-quantize activations with act_quant
          and run fp8_gemm.

    Args:
        x (torch.Tensor): The input tensor.
        weight (torch.Tensor): The weight tensor; fp8-quantized weights
            carry a per-block `.scale` attribute.
        bias (Optional[torch.Tensor]): The bias tensor to be added. Default is None.

    Returns:
        torch.Tensor: The result of the linear transformation.
    """
    # (Removed unused `element_size` / `typ` locals and dead debug prints.)
    if weight.element_size() > 1:
        return F.linear(x, weight, bias)
    if gemm_impl == "bf16":
        # Dequantize once and fall back to a bf16 GEMM.
        w = weight_dequant(weight, weight.scale)
        return F.linear(x, w, bias)
    # fp8 path: quantize activations block-wise, then fp8 GEMM.
    x_q, scale = act_quant(x, block_size)
    y = fp8_gemm(x_q, scale, weight, weight.scale)
    if bias is not None:
        y += bias
    return y
191
+
192
def linear_int(x: torch.Tensor, weight: torch.Tensor, x_rescale, weight_rescale, res_rescale, bias: Optional[torch.Tensor] = None) -> tuple[torch.Tensor]:
    """Integer linear transform; returns a (quotient, remainder) pair.

    Unquantized weights go through the int64 GEMM; fp8 weights fall back to
    the float paths with a zero remainder placeholder.
    """
    if weight.element_size() > 1:
        return int64_bmm_broadcast(x, weight, x_rescale, weight_rescale, res_rescale)
    if gemm_impl == "bf16":
        w = weight_dequant(weight, weight.scale)
        return (F.linear(x, w, bias), torch.tensor(0, dtype=torch.int64))
    print('linear act_quant', flush=True)
    x_q, scale = act_quant(x, block_size)
    y = fp8_gemm(x_q, scale, weight, weight.scale)
    if bias is not None:
        y += bias
    return (y, torch.tensor(0, dtype=torch.int64))
207
+
208
class Linear_int(nn.Module):
    """
    Integer (fixed-point) linear layer with rescale factors and optional bias.

    Args:
        layer_id: Identifier of the owning layer (kept for debugging).
        in_features (int): Number of input features.
        out_features (int): Number of output features.
        x_rescale: Fixed-point rescale applied to the activations.
        weight_rescale: Fixed-point rescale applied to the weights.
        res_rescale: Fixed-point rescale applied to the result.
        dtype: Storage dtype of the weight buffer (typically torch.int64).
        bias (bool): Whether to include a bias term. Defaults to False.
    """
    # Class-level default dtype for this layer family.
    dtype = torch.int64

    def __init__(self, layer_id, in_features: int, out_features: int, x_rescale, weight_rescale, res_rescale, dtype, bias: bool = False):
        super().__init__()
        self.layer_id = layer_id
        self.in_features = in_features
        self.out_features = out_features

        self.x_rescale = x_rescale
        self.weight_rescale = weight_rescale
        self.res_rescale = res_rescale

        # Weight is a buffer (loaded from a checkpoint, never trained).
        self.register_buffer("weight", torch.empty(out_features, in_features, dtype=dtype))

        if bias:
            self.bias = nn.Parameter(torch.empty(out_features))
        else:
            self.register_parameter("bias", None)

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor]:
        # Returns (quotient, remainder) from the integer GEMM.
        q, r = linear_int(x, self.weight, self.x_rescale, self.weight_rescale, self.res_rescale, self.bias)
        return q, r
240
+
241
class Linear_rescale_int(nn.Module):
    """
    Integer linear layer whose result rescale is loaded at checkpoint time.

    Unlike Linear_int, the result rescale is not fixed at construction: it
    is read from the "scale" buffer (populated by the state dict) on every
    forward call.

    Args:
        layer_id: Identifier of the owning layer (kept for debugging).
        in_features (int): Number of input features.
        out_features (int): Number of output features.
        x_rescale: Fixed-point rescale applied to the activations.
        weight_rescale: Fixed-point rescale applied to the weights.
        dtype: Storage dtype of the weight buffer (typically torch.int64).
        bias (bool): Whether to include a bias term. Defaults to False.
    """
    # Class-level default dtype for this layer family.
    dtype = torch.int64

    def __init__(self, layer_id, in_features: int, out_features: int, x_rescale, weight_rescale, dtype, bias: bool = False):
        super().__init__()
        self.layer_id = layer_id
        self.in_features = in_features
        self.out_features = out_features

        self.x_rescale = x_rescale
        self.weight_rescale = weight_rescale

        self.register_buffer("weight", torch.empty(out_features, in_features, dtype=dtype))
        # Result rescale, overwritten when the checkpoint is loaded.
        self.register_buffer("scale", torch.tensor(0, dtype=torch.int32))

        if bias:
            self.bias = nn.Parameter(torch.empty(out_features))
        else:
            self.register_parameter("bias", None)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        rescale = self.scale.item()
        y, _r = linear_int(x, self.weight, self.x_rescale, self.weight_rescale, rescale, self.bias)
        return y
274
+
275
class Linear(nn.Module):
    """
    Custom linear layer with support for quantized weights and optional bias.

    Args:
        layer_id: Identifier of the owning layer (kept for debugging).
        in_features (int): Number of input features.
        out_features (int): Number of output features.
        bias (bool): Whether to include a bias term. Defaults to False.
        dtype (optional): Data type for the layer. Defaults to `torch.bfloat16`.
    """
    dtype = torch.bfloat16

    def __init__(self, layer_id, in_features: int, out_features: int, bias: bool = False, dtype = None):
        super().__init__()
        self.layer_id = layer_id
        self.in_features = in_features
        self.out_features = out_features
        self.weight = nn.Parameter(torch.empty(out_features, in_features, dtype=dtype or Linear.dtype))

        # Parameter.element_size() is the number of bytes per element:
        #   torch.float32        -> 4 bytes
        #   torch.float64        -> 8 bytes
        #   torch.int64          -> 8 bytes
        #   torch.bfloat16       -> 2 bytes
        #   torch.float8_e4m3fn  -> 1 byte
        # So element_size == 1 identifies fp8-quantized weights, which carry
        # a per-block scale tensor (one scale per block_size x block_size tile).
        if self.weight.element_size() == 1:
            scale_out_features = (out_features + block_size - 1) // block_size
            scale_in_features = (in_features + block_size - 1) // block_size

            self.weight.scale = self.scale = nn.Parameter(torch.empty(scale_out_features, scale_in_features, dtype=torch.float32))
        else:
            self.register_parameter("scale", None)

        if bias:
            self.bias = nn.Parameter(torch.empty(out_features))
        else:
            self.register_parameter("bias", None)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass for the custom linear layer.

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            torch.Tensor: Transformed tensor after linear computation.
        """
        return linear(x, self.weight, self.bias)
326
+
327
+
328
class ColumnParallelLinear(Linear):
    """
    Linear layer with column parallelism, splitting output features across distributed processes.

    Args:
        layer_id: Identifier of the owning layer (kept for debugging).
        in_features (int): Number of input features.
        out_features (int): Total number of output features (split evenly
            across world_size ranks).
        bias (bool): Whether to include a bias term. Defaults to False.
        dtype (optional): Data type for the layer. Defaults to `torch.bfloat16`.
    """
    def __init__(self, layer_id, in_features: int, out_features: int, bias: bool = False, dtype = None):
        assert out_features % world_size == 0, f"Output features must be divisible by world size (world_size={world_size})"
        self.part_out_features = out_features // world_size
        super().__init__(layer_id, in_features, self.part_out_features, bias, dtype)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass for column parallel linear layer.

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            torch.Tensor: This rank's slice of the output features (no
            gather is performed here).
        """
        y = linear(x, self.weight, self.bias)
        return y
355
+
356
class ColumnParallelLinear_int(Linear_int):
    """Column-parallel variant of Linear_int: output features are split
    evenly across distributed processes; forward returns only the quotient."""
    def __init__(self, layer_id, in_features: int, out_features: int, x_rescale, weight_rescale, res_rescale, dtype, bias: bool = False):
        assert out_features % world_size == 0, f"Output features must be divisible by world size (world_size={world_size})"
        self.part_out_features = out_features // world_size
        super().__init__(layer_id, in_features, self.part_out_features, x_rescale, weight_rescale, res_rescale, dtype, bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out, _remainder = linear_int(x, self.weight, self.x_rescale, self.weight_rescale, self.res_rescale, self.bias)
        return out
365
+
366
class ColumnParallelLinear_rescale_int(Linear_int):
    """Column-parallel integer linear whose result rescale is loaded from a
    checkpoint "scale" buffer instead of being fixed at construction."""
    def __init__(self, layer_id, in_features: int, out_features: int, x_rescale, weight_rescale, dtype, bias: bool = False):
        assert out_features % world_size == 0, f"Output features must be divisible by world size (world_size={world_size})"
        self.part_out_features = out_features // world_size
        # res_rescale is a placeholder (1); the effective value is read from
        # the "scale" buffer on each forward call.
        super().__init__(layer_id, in_features, self.part_out_features, x_rescale, weight_rescale, 1, dtype, bias)
        self.register_buffer("scale", torch.tensor(0, dtype=torch.int32))
        # self.res_rescale = self.scale

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        rescale = self.scale.item()
        y, _r = linear_int(x, self.weight, self.x_rescale, self.weight_rescale, rescale, self.bias)
        return y
378
+
379
+
380
class RowParallelLinear(Linear):
    """
    Linear layer with row parallelism, splitting input features across distributed processes.

    Args:
        layer_id: Identifier of the owning layer (kept for debugging).
        in_features (int): Total number of input features (split evenly
            across world_size ranks).
        out_features (int): Number of output features.
        bias (bool): Whether to include a bias term. Defaults to False.
        dtype (optional): Data type for the layer. Defaults to `torch.bfloat16`.
    """
    def __init__(self, layer_id, in_features: int, out_features: int, bias: bool = False, dtype = None):
        assert in_features % world_size == 0, f"Input features must be divisible by world size (world_size={world_size})"
        self.part_in_features = in_features // world_size
        super().__init__(layer_id, self.part_in_features, out_features, bias, dtype)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass for row parallel linear layer.

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            torch.Tensor: Transformed tensor; partial products are summed
            across ranks, and the bias is added once, after the reduce.
        """
        y = linear(x, self.weight)
        if world_size > 1:
            dist.all_reduce(y)
        if self.bias is not None:
            y += self.bias
        return y
411
+
412
class RowParallelLinear_rescale_int(Linear_int):
    """
    Integer row-parallel linear layer whose result rescale is loaded from a
    checkpoint buffer (``scale``).

    Input features are split across distributed processes; partial products
    are summed with an all-reduce before the bias is added.

    Args:
        in_features (int): Total number of input features (sharded across ranks).
        out_features (int): Number of output features.
        x_rescale: Fixed-point rescale applied to the input.
        weight_rescale: Fixed-point rescale applied to the weight.
        res_rescale: Ignored in practice; the effective rescale comes from `scale`.
        dtype: Storage dtype of the quantized weight.
        bias (bool): Whether to include a bias term. Defaults to False.
    """
    def __init__(self, layer_id, in_features: int, out_features: int, x_rescale, weight_rescale, res_rescale, dtype, bias: bool = False):
        assert in_features % world_size == 0, f"Input features must be divisible by world size (world_size={world_size})"
        self.part_in_features = in_features // world_size
        super().__init__(layer_id, self.part_in_features, out_features, x_rescale, weight_rescale, res_rescale, dtype, bias)
        # Per-layer result shift, populated when the checkpoint is loaded.
        self.register_buffer("scale", torch.tensor(0, dtype=torch.int32))
        self.res_rescale = self.scale  # useless: forward() reads self.scale directly

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass for row parallel linear layer.

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            torch.Tensor: Transformed tensor with row-parallel computation.
        """
        # rescale = 2 ** self.scale.item()
        rescale = self.scale.item()
        # print(f'RowParallelLinear_rescale_int forward scale: {self.scale} ' + str(rescale), flush=True)
        y, _ = linear_int(x, self.weight, self.x_rescale, self.weight_rescale, rescale, self.bias)
        # Sum partial products across ranks, then add the bias exactly once.
        if world_size > 1:
            dist.all_reduce(y)
        if self.bias is not None:
            y += self.bias
        return y
448
+
449
+
450
+ class RMSNorm(nn.Module):
451
+ """
452
+ Root Mean Square Layer Normalization (RMSNorm).
453
+
454
+ Args:
455
+ dim (int): Dimension of the input tensor.
456
+ eps (float): Epsilon value for numerical stability. Defaults to 1e-6.
457
+ """
458
+ def __init__(self, dim: int, eps: float = 1e-6):
459
+ super().__init__()
460
+ self.dim = dim
461
+ self.eps = eps
462
+ self.weight = nn.Parameter(torch.ones(dim))
463
+
464
+ def forward(self, x: torch.Tensor):
465
+ """
466
+ Forward pass for RMSNorm.
467
+
468
+ Args:
469
+ x (torch.Tensor): Input tensor.
470
+
471
+ Returns:
472
+ torch.Tensor: Normalized tensor with the same shape as input.
473
+ """
474
+ return F.rms_norm(x, (self.dim,), self.weight, self.eps)
475
+
476
class RMSNorm_int(nn.Module):
    """Fixed-point (integer) RMSNorm backed by the ``RMS_Norm_int64`` kernel.

    Args:
        dim (int): Dimension of the input tensor.
        dtype: Storage dtype of the quantized weight buffer (e.g. torch.int32).
        eps (float): Kept for interface parity with RMSNorm; the integer
            kernel does not receive it.
    """
    def __init__(self, dim: int, dtype, eps: float = 1e-6):
        super().__init__()
        self.dim = dim
        self.eps = eps
        # Quantized gain, loaded from the checkpoint (a buffer, not a Parameter).
        self.register_buffer(
            "weight",
            torch.ones(dim, dtype=dtype))

    def forward(self, x: torch.Tensor):
        # Fixed-point bookkeeping (original author's notes, not verified here):
        #   x has rescale 2^31; weight has rescale 2^15 (values in 2^7..2^14);
        #   rms has rescale 2^28; the returned result has rescale 2^16 because
        #   the kernel divides by (1 << 15): 44 + 15 - 28 - 15 = 16.
        # NOTE(review): only x[0] (the first batch element) is normalized and
        # re-wrapped with c[None, :] — this assumes batch size 1; confirm
        # against the callers before using larger batches.
        (c, rms) = RMS_Norm_int64(x[0], self.weight, 1, self.dim)

        return (c[None, :], rms)
493
+
494
+
495
def precompute_freqs_cis(args: ModelArgs) -> torch.Tensor:
    """
    Precomputes frequency-based complex exponential values for rotary
    positional embeddings, quantized to int64 fixed point (rescale 2^42).

    Args:
        args (ModelArgs): Model arguments containing positional embedding parameters.

    Returns:
        torch.Tensor: int64 tensor of shape [max_seq_len, qk_rope_head_dim]
            holding interleaved (real, imag) values scaled by 2^42.
    """
    # dim = 64 (qk_rope_head_dim) in the shipped configs
    dim = args.qk_rope_head_dim
    seqlen = args.max_seq_len
    beta_fast = args.beta_fast
    beta_slow = args.beta_slow
    base = args.rope_theta
    factor = args.rope_factor

    def find_correction_dim(num_rotations, dim, base, max_seq_len):
        """
        Computes the correction dimension for a given number of rotations in the rotary positional embedding.

        Args:
            num_rotations (float): Number of rotations to compute the correction for.
            dim (int): Dimensionality of the embedding space.
            base (float): Base value for the exponential computation.
            max_seq_len (int): Maximum sequence length.

        Returns:
            float: The correction dimension based on the input parameters.
        """
        return dim * math.log(max_seq_len / (num_rotations * 2 * math.pi)) / (2 * math.log(base))

    def find_correction_range(low_rot, high_rot, dim, base, max_seq_len):
        """
        Computes the range of correction dimensions for rotary positional embeddings.

        Args:
            low_rot (float): Lower bound for the number of rotations.
            high_rot (float): Upper bound for the number of rotations.
            dim (int): Dimensionality of the embedding space.
            base (float): Base value for the exponential computation.
            max_seq_len (int): Maximum sequence length.

        Returns:
            Tuple[int, int]: The range of correction dimensions (low, high), clamped to valid indices.
        """
        low = math.floor(find_correction_dim(low_rot, dim, base, max_seq_len))
        high = math.ceil(find_correction_dim(high_rot, dim, base, max_seq_len))
        return max(low, 0), min(high, dim-1)

    def linear_ramp_factor(min, max, dim):
        """
        Computes a linear ramp function used to smooth values between a minimum and maximum range.

        Args:
            min (float): Minimum value for the ramp function.
            max (float): Maximum value for the ramp function.
            dim (int): Dimensionality of the ramp tensor.

        Returns:
            torch.Tensor: A tensor of shape (dim,) with values linearly interpolated between 0 and 1,
                clamped to the range [0, 1].
        """
        if min == max:
            max += 0.001
        linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min)
        ramp_func = torch.clamp(linear_func, 0, 1)
        return ramp_func

    # torch.arange(0, dim, 2) yields the even indices 0, 2, ..., dim-2, so
    # freqs[k] = 1 / base^(2k/dim): the classic RoPE frequency ladder,
    # a 1-D tensor of length dim/2.
    freqs = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
    # YaRN-style frequency correction when extrapolating beyond the original
    # training length (original_seq_len = 4096).
    if seqlen > args.original_seq_len:
        low, high = find_correction_range(beta_fast, beta_slow, dim, base, args.original_seq_len)
        smooth = 1 - linear_ramp_factor(low, high, dim // 2)
        freqs = freqs / factor * (1 - smooth) + freqs * smooth

    t = torch.arange(seqlen)
    # torch.outer computes every position/frequency product, e.g.
    #   t = [1, 2, 3], freqs = [10, 20]  ->  [[10, 20], [20, 40], [30, 60]]
    # so freqs becomes shape [seqlen, dim/2].
    freqs = torch.outer(t, freqs)
    # torch.polar(abs, angle) converts polar coordinates (r, theta) into the
    # complex number r*exp(i*theta); with unit magnitude this is exp(i*theta).
    # freqs_cis_0 has shape [seqlen, dim/2].
    freqs_cis_0 = torch.polar(torch.ones_like(freqs), freqs)

    # return freqs_cis_0

    # View each complex value as a (real, imag) pair: shape [seqlen, dim].
    freqs_cis_1 = torch.view_as_real(freqs_cis_0)

    # freqs_cis = torch.empty_like(freqs_cis_1, dtype=torch.int64, device='cuda')

    # Quantize to int64 fixed point with rescale 2^42. (Per the author's
    # notes, the downstream kernel's rescale parameter 19 = 42 - 23
    # compensates for this factor.)
    freqs_cis = (freqs_cis_1 * (2 ** 42)).round().to(torch.int64)

    # Debug: report the magnitude range of the quantized table.
    freqs_cis_abs = freqs_cis.abs()
    min1 = freqs_cis_abs.min()
    max1 = freqs_cis_abs.max()
    print(f'freqs_cis min {min1}, max: {max1}', flush=True)

    # print(f'freqs_cis: {freqs_cis}')
    # freqs_cis carries rescale 2^42.
    return freqs_cis
608
+
609
# x (q_pe) has shape [batch, seq_len, 128, 64]
def apply_rotary_emb(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
    """
    Applies rotary positional embeddings to the input tensor using the
    fixed-point complex-multiply kernel.

    Args:
        x (torch.Tensor): int64 fixed-point input; the last dimension is
            interpreted as interleaved (real, imag) pairs.
        freqs_cis (torch.Tensor): Precomputed int64 complex exponentials
            (rescale 2^42) from ``precompute_freqs_cis``.

    Returns:
        torch.Tensor: Tensor with rotary embeddings applied, same shape as `x`.
    """

    # if x.dtype == torch.int64:
    # Reshape to [batch, seq_len, heads, dim/2, 2] so pairs form complex numbers.
    ### IMPORTANT: the extension (.so) kernel requires contiguous memory.
    x = x.contiguous().view(*x.shape[:-1], -1, 2)
    # Broadcastable frequency table: [1, seq_len, 1, dim/2, 2].
    freqs_cis = freqs_cis.view(1, x.size(1), 1, x.size(-2), 2)
    # freqs_cis = freqs_cis.view(1, x.size(1), 1, x.size(-1))
    # Author's notes on the kernel's fixed-point constants:
    #   4194304 = 1 << (64 - 42); 42 is the rescale, and the high 64 bits of
    #   the int64*int64 product are multiplied by 4194304.
    #   4398046511104 = 1 << 42.
    # print(x)
    # print(f'x shape: {x.shape}, freqs_cis shape: {freqs_cis.shape}')
    # y = complex_int64_mul_broadcast(x, freqs_cis, 4194304, 4398046511104)
    y = complex_int64_mul_broadcast(x, freqs_cis)
    # Fold the (real, imag) pairs back into the last dimension.
    y2 = y.flatten(3)
    return y2
637
+
638
+
639
+ def getBF16PrintStr(ele):
640
+ v = int(ele.cpu().view(torch.uint16).item())
641
+ ex = v >> 7 & 0xFF
642
+ r = '(1+' + str(v & 0x7F) + '/128)'
643
+ rraw = v & 0x7F
644
+
645
+ if v & 0x8000:
646
+ vstr = '-' + r + '*2^' + str(ex - 127)
647
+ else:
648
+ vstr = r + '*2^' + str(ex - 127)
649
+ return vstr
650
+
651
class MLA(nn.Module):
    """
    Multi-Headed Attention Layer (MLA), integer-quantized variant.

    Queries, keys and values go through low-rank down/up projections; all
    matrix multiplies run through fixed-point integer kernels (`linear_int`,
    `einsum_*`, `softmax_q19`) whose results carry explicit power-of-two
    rescales noted in the inline comments.

    Attributes:
        dim (int): Dimensionality of the input features.
        n_heads (int): Number of attention heads.
        n_local_heads (int): Number of local attention heads for distributed systems.
        q_lora_rank (int): Rank for low-rank query projection (0 = no compression).
        kv_lora_rank (int): Rank for low-rank key/value projection.
        qk_nope_head_dim (int): Dimensionality of non-positional query/key projections.
        qk_rope_head_dim (int): Dimensionality of rotary-positional query/key projections.
        qk_head_dim (int): Total dimensionality of query/key projections.
        v_head_dim (int): Dimensionality of value projections.
        softmax_scale1, softmax_scale2 (int): Integer ratio replacing the float
            softmax scale (scores are multiplied by scale1, floor-divided by scale2).
    """
    def __init__(self, layer_id, args: ModelArgs):
        super().__init__()

        # RowParallelLinear / ColumnParallelLinear shard a [in, out] weight by
        # rows ([in/world_size, out]) or columns ([in, out/world_size]) across
        # the world_size devices; each device holds one shard.

        self.layer_id = layer_id

        # Model width (7168 in the 671B config, per the author's notes).
        self.dim = args.dim
        # Total number of attention heads (e.g. 128).
        self.n_heads = args.n_heads
        # Number of heads handled by this rank.
        self.n_local_heads = args.n_heads // world_size
        # Query down-projection rank; 0 disables compression (1536 in practice).
        self.q_lora_rank = args.q_lora_rank
        # Key/value down-projection rank (512 in practice).
        self.kv_lora_rank = args.kv_lora_rank
        # Query/key head dim without positional info (128 in practice).
        self.qk_nope_head_dim = args.qk_nope_head_dim
        # Query/key head dim carrying RoPE positional info (64 in practice).
        self.qk_rope_head_dim = args.qk_rope_head_dim

        # Total per-head query/key dim (192 in practice).
        self.qk_head_dim = args.qk_nope_head_dim + args.qk_rope_head_dim
        # Value head dim (128 in practice).
        self.v_head_dim = args.v_head_dim

        # With q_lora_rank == 0 the query uses a single full projection;
        # otherwise it is factored into down-projection + norm + up-projections.
        if self.q_lora_rank == 0:
            self.wq = ColumnParallelLinear(layer_id, self.dim, self.n_heads * self.qk_head_dim)
        else:
            # Query down-projection, weight shape [dim, q_lora_rank].
            self.wq_a = Linear_int(layer_id, self.dim, self.q_lora_rank, 1, 1, 30, torch.int32)
            self.q_norm = RMSNorm_int(self.q_lora_rank, torch.int32)
            # Query up-projection, split into a "nope" part and a "rope" part
            # instead of one fused [q_lora_rank, n_heads * qk_head_dim] matmul.
            # self.wq_b = ColumnParallelLinear_int(layer_id, self.q_lora_rank, self.n_heads * self.qk_head_dim, 1, 1, (1 << 30), torch.int32)
            self.wq_b1 = ColumnParallelLinear_int(layer_id, self.q_lora_rank, self.n_heads * args.qk_nope_head_dim, 1, 1, 30, torch.int32)
            self.wq_b2 = ColumnParallelLinear_int(layer_id, self.q_lora_rank, self.n_heads * args.qk_rope_head_dim, 1, 1, 30, torch.int32)

        # Key/value down-projection, likewise split into the latent part
        # (kv_lora_rank) and the rope part (qk_rope_head_dim).
        # self.wkv_a = Linear_int(layer_id, self.dim, self.kv_lora_rank + self.qk_rope_head_dim, 1, 1, (1 << 29), torch.int32)
        self.wkv_a1 = Linear_int(layer_id, self.dim, self.kv_lora_rank, 1, 1, 29, torch.int32)
        self.wkv_a2 = Linear_int(layer_id, self.dim, self.qk_rope_head_dim, 1, 1, 29, torch.int32)
        # self.kv_norm = RMSNorm(self.kv_lora_rank)
        self.kv_norm = RMSNorm_int(self.kv_lora_rank, torch.int32)
        # Key/value up-projection, split into the key ("nope") part and the
        # value part; each weight reshapes to [n_heads, head_dim, kv_lora_rank].
        # self.wkv_b = ColumnParallelLinear(layer_id, self.kv_lora_rank, self.n_heads * (self.qk_nope_head_dim + self.v_head_dim))
        self.wkv_b_1 = ColumnParallelLinear_rescale_int(layer_id, self.kv_lora_rank, self.n_heads * self.qk_nope_head_dim, 1, 1, torch.int32)
        self.wkv_b_2 = ColumnParallelLinear_rescale_int(layer_id, self.kv_lora_rank, self.n_heads * self.v_head_dim, 1, 1, torch.int32)

        # Output projection back to the model dimension.
        self.wo = RowParallelLinear_rescale_int(layer_id, self.n_heads * self.v_head_dim, self.dim, 1, 1, 1, torch.int32)
        # The original float softmax scale (below, commented out) is replaced
        # by the integer ratio softmax_scale1 / softmax_scale2.
        # self.softmax_scale = self.qk_head_dim ** -0.5
        # # max_seq_len = 4096 * 4, original_seq_len = 4096
        # if args.max_seq_len > args.original_seq_len:
        #     # mscale = 1.0, rope_factor = 40, math.log is the natural log
        #     mscale = 0.1 * args.mscale * math.log(args.rope_factor) + 1.0
        #     self.softmax_scale = self.softmax_scale * mscale * mscale
        self.softmax_scale1 = 94
        self.softmax_scale2 = 695

        if attn_impl == "naive":
            # Full per-head key/value caches for the naive attention path.
            self.register_buffer("k_cache", torch.zeros(args.max_batch_size, args.max_seq_len, self.n_local_heads, self.qk_head_dim), persistent=False)
            self.register_buffer("v_cache", torch.zeros(args.max_batch_size, args.max_seq_len, self.n_local_heads, self.v_head_dim), persistent=False)
        else:
            # Latent KV cache: stores the normalized down-projected KV...
            # self.register_buffer("kv_cache", torch.zeros(args.max_batch_size, args.max_seq_len, self.kv_lora_rank), persistent=False)
            # self.register_buffer("kv_cache", torch.zeros(1, args.max_seq_len, self.kv_lora_rank), persistent=False)
            self.register_buffer("kv_cache", torch.zeros(1, args.max_seq_len, self.kv_lora_rank, dtype=torch.int64), persistent=False)
            # ...and the rotary-embedded key part. Both are batch-size-1 int64
            # buffers (hard-coded leading dim 1).
            # self.register_buffer("pe_cache", torch.zeros(args.max_batch_size, args.max_seq_len, self.qk_rope_head_dim), persistent=False)
            # self.register_buffer("pe_cache", torch.zeros(1, args.max_seq_len, self.qk_rope_head_dim), persistent=False)
            self.register_buffer("pe_cache", torch.zeros(1, args.max_seq_len, self.qk_rope_head_dim, dtype=torch.int64), persistent=False)

    # x has shape [1, seq_len, dim] and (per the author's notes) rescale 2^21.
    def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor]):
        """
        Forward pass for the Multi-Headed Attention Layer (MLA).

        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, seq_len, dim).
            start_pos (int): Starting position in the sequence for caching.
            freqs_cis (torch.Tensor): Precomputed complex exponential values for rotary embeddings.
            mask (Optional[torch.Tensor]): Mask tensor to exclude certain positions from attention.

        Returns:
            torch.Tensor: Output tensor with the same shape as the input.
        """

        # end_pos marks where this chunk ends inside the KV caches.
        bsz, seqlen, _ = x.size()
        end_pos = start_pos + seqlen

        # Query projection: with compression (q_lora_rank != 0) the input goes
        # through wq_a (down), q_norm, then wq_b1/wq_b2 (up); otherwise wq.
        # The result is reshaped to [bsz, seqlen, n_local_heads, head_dim].
        if self.q_lora_rank == 0:
            q = self.wq(x)
        else:
            # wq_a weight rescale 2^30; q_down keeps the input's 2^21 rescale.
            q_down, q_down_rem = self.wq_a(x)
            # q_down = self.wq_a(x)

            if snark:
                # Dump inputs/outputs for the zero-knowledge proof pipeline.
                dirStr = f'zkdata/pos_{start_pos}/layer_{self.layer_id}'
                os.makedirs(dirStr, exist_ok=True)
                saveTensor(f'{dirStr}/wq_a_x.bin', x.cpu())
                saveTensor(f'{dirStr}/wq_a_w.bin', self.wq_a.weight.view(torch.uint32).cpu())
                saveTensor(f'{dirStr}/wq_a_y.bin', q_down.cpu())
                saveTensor(f'{dirStr}/q_norm_r.bin', q_down_rem.cpu())
            # q_down = (q_down.detach().to(torch.float32) * (2 ** -23)).to(torch.bfloat16)

            # q_normed carries rescale 2^19 (author's note).
            (q_normed, rms) = self.q_norm(q_down)

            if snark:
                dirStr = f'zkdata/pos_{start_pos}/layer_{self.layer_id}'
                os.makedirs(dirStr, exist_ok=True)
                saveTensor(f'{dirStr}/q_norm_x.bin', q_down.cpu())
                saveTensor(f'{dirStr}/q_norm_weight.bin', self.q_norm.weight.view(torch.uint32).cpu())
                saveTensor(f'{dirStr}/q_norm_rms.bin', rms.cpu())
                saveTensor(f'{dirStr}/q_norm_y.bin', q_normed.cpu())

            # Up-project separately into the "nope" and "rope" query parts.
            # q = self.wq_b(q_normed)
            q_nope = self.wq_b1(q_normed)
            q_pe = self.wq_b2(q_normed)

        # Reshape into per-head layout.
        # q = q.view(bsz, seqlen, self.n_local_heads, self.qk_head_dim)
        q_nope = q_nope.view(bsz, seqlen, self.n_local_heads, self.qk_nope_head_dim)
        q_pe = q_pe.view(bsz, seqlen, self.n_local_heads, self.qk_rope_head_dim)

        # q_nope: [bsz, seqlen, heads, qk_nope_head_dim] (no positional info);
        # q_pe:   [bsz, seqlen, heads, qk_rope_head_dim] gets RoPE below.
        # q_nope, q_pe = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
        # freqs_cis has rescale 2^42; q_pe keeps rescale 2^19 after rotation.

        if snark:
            saveTensor(f'zkdata/pos_{start_pos}/layer_{self.layer_id}/q_pe_x.bin', q_pe.cpu())
            saveTensor(f'zkdata/freqs_cis.bin', freqs_cis.cpu())

        q_pe = apply_rotary_emb(q_pe, freqs_cis)

        if snark:
            # NOTE(review): this dumps q_norm.weight, not q_pe — looks like a
            # copy/paste slip in the dump path; confirm before relying on it.
            saveTensor(f'zkdata/pos_{start_pos}/layer_{self.layer_id}/q_pe_y.bin', self.q_norm.weight.view(torch.uint32).cpu())

        # Key/value path: down-project x into the latent kv representation and
        # the rope key part; the rope part gets rotary embeddings.

        # kv: [bsz, seqlen, kv_lora_rank], rescale 2^21 (author's note).
        kv, kv_rem = self.wkv_a1(x)

        if snark:
            dirStr = f'zkdata/pos_{start_pos}/layer_{self.layer_id}'
            os.makedirs(dirStr, exist_ok=True)
            saveTensor(f'{dirStr}/wkv_a1_x.bin', x.cpu())
            saveTensor(f'{dirStr}/wkv_a1_w.bin', self.wkv_a1.weight.view(torch.uint32).cpu())
            saveTensor(f'{dirStr}/wkv_a1_y.bin', kv.cpu())
            saveTensor(f'{dirStr}/wkv_a1_r.bin', kv_rem.cpu())

        k_pe, _ = self.wkv_a2(x)

        # print(f'k_pe 1 shape: {k_pe.shape}', flush=True)
        # unsqueeze(2) inserts a head axis: k_pe becomes [bsz, seqlen, 1, dim]
        # so the rope kernel broadcasts over heads.
        k_pe = apply_rotary_emb(k_pe.unsqueeze(2), freqs_cis)
        # print(f'k_pe 2 shape: {k_pe.shape}', flush=True)

        if attn_impl == "naive":
            # NOTE(review): this branch references self.wkv_b and
            # self.softmax_scale, which are not defined in this quantized
            # __init__ (only the commented-out float versions) — the naive
            # path appears broken as-is; confirm before enabling it.
            q = torch.cat([q_nope, q_pe], dim=-1)
            kv = self.wkv_b(self.kv_norm(kv))
            kv = kv.view(bsz, seqlen, self.n_local_heads, self.qk_nope_head_dim + self.v_head_dim)
            k_nope, v = torch.split(kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
            k = torch.cat([k_nope, k_pe.expand(-1, -1, self.n_local_heads, -1)], dim=-1)
            self.k_cache[:bsz, start_pos:end_pos] = k
            self.v_cache[:bsz, start_pos:end_pos] = v
            scores = torch.einsum("bshd,bthd->bsht", q, self.k_cache[:bsz, :end_pos]) * self.softmax_scale
        else:
            # Absorbed attention: fold wkv_b_1 into the query so scores can be
            # computed directly against the latent kv cache.
            #   score1 = q_nope (x wkv_b_1) . kv_cache
            #   score2 = q_pe . pe_cache
            # then scale by the integer softmax ratio.

            # q_nope: [bsz, seqlen, heads, qk_nope_head_dim];
            # wkv_b_1 reshaped to [heads, qk_nope_head_dim, kv_lora_rank].
            # q_nope rescale 2^19, wkv_b_1 rescale from its scale buffer.
            # q_nope = torch.einsum("bshd,hdc->bshc", q_nope, wkv_b_1)
            # After einsum_bshd_hdc_bshc, q_nope: [bsz, seqlen, heads, kv_lora_rank].
            wkv_b_1 = self.wkv_b_1.weight.view(self.n_local_heads, -1, self.kv_lora_rank)
            q_nope = einsum_bshd_hdc_bshc(q_nope.contiguous(), wkv_b_1.contiguous(), self.wkv_b_1.scale.item())
            # print('q_nope type: ' + str(q_nope.type()))
            # print('q_nope shape: ' + str(q_nope.shape))

            # kv_normed carries rescale 2^23 (author's note).
            (kv_normed, rms) = self.kv_norm(kv)

            # Store this chunk in the latent caches.
            self.kv_cache[:bsz, start_pos:end_pos] = kv_normed
            # self.kv_cache[:bsz, start_pos:end_pos] = kv2

            # kv = (kv.detach().to(torch.float32) * (2 ** -23)).to(torch.bfloat16)
            # pe_cache carries rescale 2^21 (author's note).
            self.pe_cache[:bsz, start_pos:end_pos] = k_pe.squeeze(2)

            # score1: q_nope (rescale 2^19) against kv_cache (rescale 2^23);
            # the kernel shifts by 23 so score1 keeps rescale 2^19.
            # score1 = torch.einsum("bshc,btc->bsht", q_nope, self.kv_cache[:bsz, :end_pos])
            kv_cache1 = self.kv_cache[:bsz, :end_pos]
            # score1 = einsum_bshc_btc_bsht(q_nope.contiguous(), kv_cache1.contiguous(), 25)
            score1 = einsum_bshc_btc_bsht(q_nope.contiguous(), kv_cache1.contiguous(), 23)
            # print(f'kv_cache1 type: {kv_cache1.type()}, shape: {kv_cache1.shape}', flush=True)
            # score1 = (score1.detach().to(torch.float32) * (2 ** -21)).to(torch.bfloat16)

            # score2: q_pe (rescale 2^19) against pe_cache (rescale 2^21);
            # the kernel shifts by 21 so score2 also keeps rescale 2^19.
            # score2 = torch.einsum("bshr,btr->bsht", q_pe, self.pe_cache[:bsz, :end_pos])
            pe_cache1 = self.pe_cache[:bsz, :end_pos]
            # score2 = einsum_bshc_btc_bsht(q_pe.contiguous(), pe_cache1.contiguous(), 23)
            score2 = einsum_bshc_btc_bsht(q_pe.contiguous(), pe_cache1.contiguous(), 21)
            # score2 = (score2.detach().to(torch.float32) * (2 ** -21)).to(torch.bfloat16)

            # Integer replacement for `scores * softmax_scale` (floor division).
            # scores = (score1 + score2) * self.softmax_scale
            scores = (score1 + score2) * self.softmax_scale1 // self.softmax_scale2
            # scores = torch.round(((score1 + score2) * self.softmax_scale1).to(torch.float32) / self.softmax_scale2).to(torch.int64)


        # mask (after unsqueeze(1)) broadcasts as [seqlen, 1, seqlen] over
        # scores of shape [bsz, seqlen, heads, t].
        if mask is not None:
            # print('mask type: ' + str(mask.type()))
            # print('mask shape: ' + str(mask.shape))
            scores += mask.unsqueeze(1)
        # Fixed-point softmax over the last dimension (replaces the float
        # softmax below); both input and output carry rescale 2^19.
        # scores = scores.softmax(dim=-1, dtype=torch.float32).type_as(x)
        scores_new = torch.empty_like(scores, dtype=torch.int64, device='cuda')

        # softmax_q19 clobbers its input (author's note), hence the dump of
        # `scores` happens before the call.
        if snark:
            saveTensor(f'zkdata/pos_{start_pos}/layer_{self.layer_id}/scores_softmax_x.bin', scores.contiguous().cpu())

        softmax_q19(scores.contiguous(), scores_new)

        if snark:
            saveTensor(f'zkdata/pos_{start_pos}/layer_{self.layer_id}/scores_softmax_y.bin', scores_new.cpu())

        if attn_impl == "naive":
            x = torch.einsum("bsht,bthd->bshd", scores, self.v_cache[:bsz, :end_pos])
        else:

            kv_cache2 = self.kv_cache[:bsz, :end_pos]
            # kv_cache2 = (kv_cache2.detach().to(torch.float32) * (2 ** -25)).to(torch.bfloat16)

            # x = (x.detach().to(torch.float32) * (2 ** -23)).to(torch.bfloat16)

            # Final output: attention-weight the latent kv cache, up-project
            # with the value matrix wkv_b_2, then apply the output projection.
            # x = torch.einsum("bsht,btc->bshc", scores_new, kv_cache2)
            # scores_new rescale 2^19, kv_cache2 rescale 2^23, bshc rescale 2^19.
            # bshc: [bsz, seqlen, heads, kv_lora_rank].
            # bshc = einsum_bsht_btc_bshc(scores_new.contiguous(), kv_cache2.contiguous(), 25)
            bshc = einsum_bsht_btc_bshc(scores_new.contiguous(), kv_cache2.contiguous(), 23)

            # # v_head_dim = 128, kv_lora_rank = 512, n_local_heads = 128
            # wkv_b_2 = wkv_b[:, -self.v_head_dim:]
            # # print('wkv_b 2 type: ' + str(wkv_b_2.type()))
            # # print('wkv_b 2 shape: ' + str(wkv_b_2.shape))
            wkv_b_2 = self.wkv_b_2.weight
            wkv_b_2 = wkv_b_2.view(self.n_local_heads, -1, self.kv_lora_rank)

            # wkv_b_2 = (wkv_b_2.detach().to(torch.float32) * (2 ** -self.wkv_b_2.scale.item())).to(torch.bfloat16)

            # x = torch.einsum("bshc,hdc->bshd", x, wkv_b_2)
            # bshc rescale 2^19; wkv_b_2 rescale from its scale buffer;
            # x comes out with rescale 2^19 (author's note).
            # bshc: [bsz, seqlen, heads, kv_lora_rank]; wkv_b_2: [heads, v_head_dim, kv_lora_rank].
            x = einsum_bshc_hdc_bshd(bshc.contiguous(), wkv_b_2.contiguous(), self.wkv_b_2.scale.item())
            # x = (x.detach().to(torch.float32) * (2 ** -21)).to(torch.bfloat16)

        # Output projection; the returned shape is [bsz, seqlen, dim].
        x = self.wo(x.flatten(2))

        return x
957
+
958
class MLP(nn.Module):
    """
    Gated feed-forward network (SwiGLU-style MLP).

    Computes ``w2(silu(w1(x)) * w3(x))``: w1 is the gate projection, w3 the
    up projection, and w2 the down projection back to the model dimension.

    Attributes:
        w1 (nn.Module): Column-parallel gate projection (dim -> inter_dim).
        w2 (nn.Module): Row-parallel down projection (inter_dim -> dim).
        w3 (nn.Module): Column-parallel up projection (dim -> inter_dim).
    """
    def __init__(self, layer_id, dim: int, inter_dim: int):
        """
        Initializes the MLP layer.

        Args:
            dim (int): Input and output dimensionality.
            inter_dim (int): Hidden layer dimensionality.
        """
        super().__init__()
        self.w1 = ColumnParallelLinear(layer_id, dim, inter_dim)
        self.w2 = RowParallelLinear(layer_id, inter_dim, dim)
        self.w3 = ColumnParallelLinear(layer_id, dim, inter_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass for the MLP layer.

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            torch.Tensor: Output tensor after MLP computation.
        """
        gate = F.silu(self.w1(x))
        up = self.w3(x)
        return self.w2(gate * up)
991
+
992
+
993
class MLP_int(nn.Module):
    """
    Integer-quantized gated feed-forward layer (SwiGLU-style).

    Computes ``w2(silu(w1(x)) * w3(x))`` in int64 fixed point: the SiLU runs
    through the ``silu_q23`` kernel and intermediate activations carry a 2^23
    rescale (per the author's notes).

    Attributes:
        w1 (nn.Module): Column-parallel gate projection.
        w2 (nn.Module): Row-parallel down projection.
        w3 (nn.Module): Column-parallel up projection.
    """
    def __init__(self, layer_id, dim: int, inter_dim: int):
        """
        Initializes the MLP layer.

        Args:
            layer_id: Index of the layer (used for debug-dump paths).
            dim (int): Input and output dimensionality.
            inter_dim (int): Hidden layer dimensionality.
        """
        super().__init__()
        self.layer_id = layer_id
        self.w1 = ColumnParallelLinear_rescale_int(layer_id, dim, inter_dim, 1, 1, torch.int32)
        self.w2 = RowParallelLinear_rescale_int(layer_id, inter_dim, dim, 1, 1, 1, torch.int32)
        self.w3 = ColumnParallelLinear_rescale_int(layer_id, dim, inter_dim, 1, 1, torch.int32)

    # Input x has rescale 2^23, shape [bsz, seq_len, dim].
    def forward(self, start_pos: int, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass for the MLP layer.

        Args:
            start_pos (int): Sequence position, used only for debug-dump paths.
            x (torch.Tensor): Input tensor (int64 fixed point, rescale 2^23).

        Returns:
            torch.Tensor: Output tensor after MLP computation.
        """
        # Gate pre-activation: [bsz, seq_len, inter_dim], rescale 2^23.
        r1 = self.w1(x)

        # s1 = F.silu(r1), computed by the fixed-point kernel; rescale 2^23.
        s1 = torch.empty_like(r1, dtype=torch.int64, device='cuda')
        # silu_q25(r1, s1)

        if snark:
            saveTensor(f'zkdata/pos_{start_pos}/layer_{self.layer_id}/mlp_silu_x.bin', r1.contiguous().cpu())

        silu_q23(r1, s1)

        if snark:
            saveTensor(f'zkdata/pos_{start_pos}/layer_{self.layer_id}/mlp_silu_y.bin', s1.cpu())

        # Up projection: rescale 2^23, shape [1, seq_len, inter_dim].
        r2 = self.w3(x)

        # The product of two 2^23 values carries rescale 2^46; shift back to
        # 2^23 before the down projection. Returned shape [bsz, seq_len, dim].
        q = self.w2(s1 * r2 // (1 << 23))
        return q
1049
+
1050
+
1051
+ class Gate(nn.Module):
1052
+ """
1053
+ Gating mechanism for routing inputs in a mixture-of-experts (MoE) model.
1054
+
1055
+ Attributes:
1056
+ dim (int): Dimensionality of input features.
1057
+ topk (int): Number of top experts activated for each input.
1058
+ n_groups (int): Number of groups for routing.
1059
+ topk_groups (int): Number of groups to route inputs to.
1060
+ score_func (str): Scoring function ('softmax' or 'sigmoid').
1061
+ route_scale (float): Scaling factor for routing weights.
1062
+ weight (torch.nn.Parameter): Learnable weights for the gate.
1063
+ bias (Optional[torch.nn.Parameter]): Optional bias term for the gate.
1064
+ """
1065
    def __init__(self, layer_id: int, args: ModelArgs):
        """
        Initializes the Gate module.

        Args:
            layer_id (int): Index of the layer (used for debug-dump paths).
            args (ModelArgs): Model arguments containing gating parameters.
        """
        super().__init__()

        self.layer_id = layer_id

        self.dim = args.dim
        # Number of experts activated per token (n_activated_experts, e.g. 8).
        self.topk = args.n_activated_experts
        # Number of expert groups (n_expert_groups, e.g. 8).
        self.n_groups = args.n_expert_groups
        # Number of groups a token may route into (n_limited_groups, e.g. 4).
        self.topk_groups = args.n_limited_groups
        # Scoring function: 'softmax' or 'sigmoid'.
        self.score_func = args.score_func
        # Scaling factor for routing weights (route_scale, e.g. 2.5).
        self.route_scale = args.route_scale
        # Quantized gate weight [n_routed_experts, dim]; a buffer rather than
        # a Parameter because it is loaded pre-quantized from the checkpoint.
        # self.weight = nn.Parameter(torch.empty(args.n_routed_experts, args.dim))
        self.register_buffer("weight", torch.empty(args.n_routed_experts, args.dim, dtype=torch.int32))
        # Per-layer result shift, populated when the checkpoint is loaded.
        self.register_buffer("scale", torch.tensor(0, dtype=torch.int32))
        # The routed-expert bias only exists in the dim == 7168 config.
        # self.bias = nn.Parameter(torch.empty(args.n_routed_experts, dtype=torch.int32)) if self.dim == 7168 else None
        if self.dim == 7168:
            self.register_buffer("bias", torch.empty(args.n_routed_experts, dtype=torch.int32))
        else:
            self.bias = None
1096
+
1097
+ # x 的 rescale 为 2^23
1098
+ def forward(self, start_pos: int, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
1099
+ """
1100
+ Forward pass for the gating mechanism.
1101
+
1102
+ Args:
1103
+ x (torch.Tensor): Input tensor.
1104
+
1105
+ Returns:
1106
+ Tuple[torch.Tensor, torch.Tensor]: Routing weights and selected expert indices.
1107
+ """
1108
+
1109
+ x = x.view(1, -1, self.dim)
1110
+
1111
+ # scores = linear(x, self.weight)
1112
+ # self.weight shape: [256, 7168]
1113
+ # 当前 scores shape: [1, seqLen, 256]
1114
+ # rescale = 2 ** self.scale.item()
1115
+ rescale = self.scale.item()
1116
+
1117
+ # scores 的 rescale 为 2^23
1118
+ scores, scores_rem = linear_int(x, self.weight, 1, 1, rescale)
1119
+ # scores = int64_bmm_with_bias(x, self.weight, bias, 1, 1, self.scale)
1120
+
1121
+ # x shape: [seqLen, 7168]
1122
+ x = x.view(-1, self.dim)
1123
+
1124
+ if self.score_func == "softmax":
1125
+ scores = scores.softmax(dim=-1, dtype=torch.float32)
1126
+ else:
1127
+ # scores = scores.sigmoid()
1128
+ C = torch.empty_like(scores, dtype=torch.int64, device='cuda')
1129
+
1130
+ if snark:
1131
+ saveTensor(f'zkdata/pos_{start_pos}/layer_{self.layer_id}/sigmoid_gate_x.bin', scores.cpu())
1132
+ saveTensor(f'zkdata/pos_{start_pos}/layer_{self.layer_id}/sigmoid_gate_r.bin', scores_rem.cpu())
1133
+
1134
+ sigmoid_q23(scores, C)
1135
+
1136
+ if snark:
1137
+ saveTensor(f'zkdata/pos_{start_pos}/layer_{self.layer_id}/sigmoid_gate_y.bin', C.cpu())
1138
+
1139
+ # 当前 scores shape: [seqLen, 256]
1140
+ scores = C.squeeze(0)
1141
+
1142
+ # bias的rescale为2^23
1143
+ original_scores = scores
1144
+ if self.bias is not None:
1145
+ # scores = scores + self.bias
1146
+ # 当前 scores shape: [seqLen, 256]
1147
+ scores = scores + self.bias
1148
+
1149
+ if snark:
1150
+ saveTensor(f'zkdata/pos_{start_pos}/layer_{self.layer_id}/gate_original_scores.bin', original_scores.contiguous().cpu())
1151
+ saveTensor(f'zkdata/pos_{start_pos}/layer_{self.layer_id}/gate_bias.bin', self.bias.view(torch.uint32).cpu())
1152
+
1153
+ # n_groups = 8
1154
+ if self.n_groups > 1:
1155
+ # x.size(0) = 8,当前 scores shape: [seqLen, 8, 32]
1156
+ scores = scores.view(x.size(0), self.n_groups, -1)
1157
+ # print(f'scores shape 111: {scores.shape}', flush=True)
1158
+ if self.bias is None:
1159
+ group_scores = scores.amax(dim=-1)
1160
+ else:
1161
+ # topk 返回 -1维度上 最大的 前 2 个值,同时返回值和索引,[0] 表示 取值,sum(-1) 再把最大的两个值相加.
1162
+ # 256维,分成8个组,每个组挑最大的两个数相加,得到 [seqLen, 8] 的结果,代表 8 个组的 最大两个值的和。
1163
+ # group_scores 的 shape: [8, 8]
1164
+ group_scores = scores.topk(2, dim=-1)[0].sum(dim=-1)
1165
+ # print(group_scores[0], flush=True)
1166
+ # print(f'group_scores shape: {group_scores.shape}')
1167
+
1168
+ # topk_groups = 4, 从 8 个group中选择最大的 4个,返回其索引,比如返回 [[0, 2, 4, 6], ...]
1169
+ # indices shape: [seqLen, 4]
1170
+ indices = group_scores.topk(self.topk_groups, dim=-1)[1]
1171
+ # print(indices[0], flush=True)
1172
+
1173
+ # mask shape: [seqLen, 8]
1174
+ # scatter_: 按照给定索引,把某个源张量的值写入到目标张量对应位置。 Tensor.scatter_(dim, index, src, reduce=None)
1175
+ # 比如 mask 为[[False, True, False, True, False, True, False, True], ...]
1176
+ # mask: 每一行最大的4个值相对应的 mask 为 False
1177
+ mask = scores.new_ones(x.size(0), self.n_groups, dtype=bool).scatter_(1, indices, False)
1178
+ # print(mask[0], flush=True)
1179
+ # 把满足布尔 mask 的位置替换成 "-inf", mask.unsqueeze(-1) shape: [8, 8, 1]
1180
+ # 把 scores 中 淘汰掉的4个group中的每一个值设置为 "-inf",总共设置 128个 "-inf",占每一行中的一半
1181
+ # scores shape: [seqLen, 256]
1182
+ # scores = scores.masked_fill_(mask.unsqueeze(-1), float("-inf")).flatten(1)
1183
+ scores = scores.masked_fill_(mask.unsqueeze(-1), -(1 << 42)).flatten(1)
1184
+
1185
+ # 没有淘汰掉的group中的 128个值中,选择最大的8个值,返回其下标
1186
+ # self.topk = 8, indices shape: [8, 8]
1187
+ indices = torch.topk(scores, self.topk, dim=-1)[1]
1188
+ # print(indices[0], flush=True)
1189
+
1190
+ # gather 用来按照索引从一个张量中取值,按照8个最大值的下标,获取其值
1191
+ # weights shape: [8, 8]
1192
+ weights = original_scores.gather(1, indices)
1193
+
1194
+ if snark:
1195
+ saveTensor(f'zkdata/pos_{start_pos}/layer_{self.layer_id}/gate_indices.bin', indices.contiguous().cpu())
1196
+ saveTensor(f'zkdata/pos_{start_pos}/layer_{self.layer_id}/gate_weights.bin', weights.contiguous().cpu())
1197
+
1198
+ # print(f'weights shape: {weights.shape}')
1199
+ if self.score_func == "sigmoid":
1200
+ sum1 = weights.sum(dim=-1, keepdim=True)
1201
+ # weights = (weights * (2 ** 25) + sum1 // 2) // sum1
1202
+ weights = (weights * (2 ** 23)) // sum1
1203
+ # weights /= weights.sum(dim=-1, keepdim=True)
1204
+
1205
+ #self.route_scale = 2.5
1206
+ # weights *= self.route_scale
1207
+ weights = weights * 5 // 2
1208
+
1209
+ # weights = (weights.to(torch.float32) * (2 ** -23)).to(torch.bfloat16)
1210
+ # return weights.type_as(x), indices
1211
+ return weights, indices
1212
+
1213
+
1214
+ class Expert_int(nn.Module):
1215
+ """
1216
+ Expert layer for Mixture-of-Experts (MoE) models.
1217
+
1218
+ Attributes:
1219
+ w1 (nn.Module): Linear layer for input-to-hidden transformation.
1220
+ w2 (nn.Module): Linear layer for hidden-to-output transformation.
1221
+ w3 (nn.Module): Additional linear layer for feature transformation.
1222
+ """
1223
+ def __init__(self, layer_id, idx, dim: int, inter_dim: int):
1224
+ """
1225
+ Initializes the Expert layer.
1226
+
1227
+ Args:
1228
+ dim (int): Input and output dimensionality.
1229
+ inter_dim (int): Hidden layer dimensionality.
1230
+ """
1231
+ super().__init__()
1232
+ # # w1 shape: [2048, 7168]
1233
+ # self.w1 = Linear(layer_id, dim, inter_dim)
1234
+ # # w2 shape: [7168, 2048]
1235
+ # self.w2 = Linear(layer_id, inter_dim, dim)
1236
+ # # w3 shape: [2048, 7168]
1237
+ # self.w3 = Linear(layer_id, dim, inter_dim)
1238
+
1239
+ self.layer_id = layer_id
1240
+ self.idx = idx
1241
+
1242
+ self.w1 = Linear_rescale_int(layer_id, dim, inter_dim, 1, 1, torch.int32)
1243
+ self.w2 = Linear_rescale_int(layer_id, inter_dim, dim, 1, 1, torch.int32)
1244
+ self.w3 = Linear_rescale_int(layer_id, dim, inter_dim, 1, 1, torch.int32)
1245
+
1246
+ def forward(self, start_pos: int, x: torch.Tensor) -> torch.Tensor:
1247
+ """
1248
+ Forward pass for the Expert layer.
1249
+
1250
+ Args:
1251
+ x (torch.Tensor): Input tensor.
1252
+
1253
+ Returns:
1254
+ torch.Tensor: Output tensor after expert computation.
1255
+ """
1256
+
1257
+ # 返回的 shape [bsz, seqLen, 7168]
1258
+ # return self.w2(F.silu(self.w1(x)) * self.w3(x))
1259
+ # r1 shape: [bsz, seqLen, 18432], r1 rescale: 2^23
1260
+ r1 = self.w1(x)
1261
+
1262
+ # s1 = F.silu(r1)
1263
+ # s1 shape: [bsz, seqLen, 18432], s1 rescale: 2^23
1264
+ s1 = torch.empty_like(r1, dtype=torch.int64, device='cuda')
1265
+ # silu_q25(r1, s1)
1266
+
1267
+ if snark:
1268
+ saveTensor(f'zkdata/pos_{start_pos}/layer_{self.layer_id}/expert_{self.idx}_silu_x.bin', r1.contiguous().cpu())
1269
+
1270
+ silu_q23(r1, s1)
1271
+
1272
+ if snark:
1273
+ saveTensor(f'zkdata/pos_{start_pos}/layer_{self.layer_id}/expert_{self.idx}_silu_y.bin', s1.cpu())
1274
+
1275
+ # r2 rescale: 2^23
1276
+ r2 = self.w3(x)
1277
+
1278
+ # 返回的 shape [bsz, seqLen, 7168]
1279
+ q = self.w2((s1 * r2) >> 23)
1280
+ return q
1281
+
1282
+
1283
+ class MoE(nn.Module):
1284
+ """
1285
+ Mixture-of-Experts (MoE) module.
1286
+
1287
+ Attributes:
1288
+ dim (int): Dimensionality of input features.
1289
+ n_routed_experts (int): Total number of experts in the model.
1290
+ n_local_experts (int): Number of experts handled locally in distributed systems.
1291
+ n_activated_experts (int): Number of experts activated for each input.
1292
+ gate (nn.Module): Gating mechanism to route inputs to experts.
1293
+ experts (nn.ModuleList): List of expert modules.
1294
+ shared_experts (nn.Module): Shared experts applied to all inputs.
1295
+ """
1296
+ def __init__(self, layer_id, args: ModelArgs, ckpt_path):
1297
+ """
1298
+ Initializes the MoE module.
1299
+
1300
+ Args:
1301
+ args (ModelArgs): Model arguments containing MoE parameters.
1302
+ """
1303
+ super().__init__()
1304
+ self.layer_id = layer_id
1305
+ self.ckpt_path = ckpt_path
1306
+ self.dim = args.dim
1307
+ self.moe_inter_dim = args.moe_inter_dim
1308
+ assert args.n_routed_experts % world_size == 0, f"Number of experts must be divisible by world size (world_size={world_size})"
1309
+ self.n_routed_experts = args.n_routed_experts
1310
+ self.n_local_experts = args.n_routed_experts // world_size
1311
+ self.n_activated_experts = args.n_activated_experts
1312
+ self.experts_start_idx = rank * self.n_local_experts
1313
+ self.experts_end_idx = self.experts_start_idx + self.n_local_experts
1314
+ self.gate = Gate(layer_id, args)
1315
+ # moe_inter_dim = 2048
1316
+ # self.experts = nn.ModuleList([Expert(layer_id, args.dim, args.moe_inter_dim) if self.experts_start_idx <= i < self.experts_end_idx else None
1317
+ # for i in range(self.n_routed_experts)])
1318
+ # self.experts = torch.nn.ModuleList()
1319
+
1320
+ # dim = 7168, n_shared_experts = 1, moe_inter_dim = 2048
1321
+ self.shared_experts = MLP_int(layer_id, args.dim, args.n_shared_experts * args.moe_inter_dim)
1322
+
1323
+ # x 的 rescale 为 2^23, shape: [1, seqLen, 7168]
1324
+ def forward(self, start_pos: int, x: torch.Tensor) -> torch.Tensor:
1325
+ """
1326
+ Forward pass for the MoE module.
1327
+
1328
+ Args:
1329
+ x (torch.Tensor): Input tensor.
1330
+
1331
+ Returns:
1332
+ torch.Tensor: Output tensor after expert routing and computation.
1333
+ """
1334
+ # ffn_normed 的 rescale 为 2^23
1335
+ # x = (x.to(torch.float32) * (2 ** -23)).to(torch.bfloat16)
1336
+
1337
+ # z rescale: 2^23, z 的 shape [seqLen, 7168]
1338
+ z = self.shared_experts(start_pos, x)
1339
+
1340
+ # x shape 之前为: [bsz, seqLen, 7168], 之后为 [8, 7168]
1341
+ shape = x.size()
1342
+ x = x.view(-1, self.dim)
1343
+
1344
+ # weights shape: [seqLen, 8], indices shape: [seqLen, 8]
1345
+ # weights 的 rescale 为 2^23
1346
+ weights, indices = self.gate(start_pos, x)
1347
+
1348
+ # y shape: [seqLen, 7168]
1349
+ y = torch.zeros_like(x)
1350
+ # torch.bincount 用来统计非负整数张量中各个数值出现的次数,类似于直方图计数
1351
+ # torch.bincount(input, weights=None, minlength=0) -> Tensor, weights: 可选的一维浮点张量,和 input 形状一致。若提供,就不是“次数统计”,而是“权重和”
1352
+ # 统计 256 个 专家 出现的次数
1353
+ counts = torch.bincount(indices.flatten(), minlength=self.n_routed_experts).tolist()
1354
+ for i in range(self.experts_start_idx, self.experts_end_idx):
1355
+ if counts[i] == 0:
1356
+ continue
1357
+ # expert = self.experts[i]
1358
+ with torch.device("cuda"):
1359
+ expert = Expert_int(self.layer_id, i, self.dim, self.moe_inter_dim)
1360
+ # load_model(expert, f'/data3/DeepSeek-V3-Demo1/experts-{self.layer_id}/{i}.safetensors')
1361
+ expertModelPath = os.path.join(self.ckpt_path, f"experts-{self.layer_id}/{i}.safetensors")
1362
+ load_model(expert, expertModelPath)
1363
+
1364
+ # 第 idx 个 token, 专家 i 出现的编号是 top
1365
+ # 比如
1366
+ # [0, 1, 3, 2, 5, 4, 6, 9]
1367
+ # [7, 8, 3, 12, 5, 11, 6, 1]
1368
+ # [16, 10, 3, 2, 15, 4, 6, 9]
1369
+ # [10, 21, 3, 2, 5, 4, 1, 9]
1370
+ # torch.where(indices == 1) 返回的结果是 ([0, 1, 3], [1, 7, 6])
1371
+ idx, top = torch.where(indices == i)
1372
+ # expert(x[idx]) 返回的 shape [seqLen, 2048], weights[idx, top, None] 的 shape 为 [seqLen, 1], 包含一个 weight 值
1373
+ # y[idx] += expert(x[idx]) * weights[idx, top, None]
1374
+ x2 = x[idx].unsqueeze(0)
1375
+ y2 = expert(start_pos, x2)
1376
+ y2 = y2.view(-1, self.dim)
1377
+ # y[idx] += y2 * weights[idx, top, None] // (1 << 25)
1378
+ y[idx] += y2 * weights[idx, top, None] // (1 << 23)
1379
+ # z = self.shared_experts(x)
1380
+ if world_size > 1:
1381
+ dist.all_reduce(y)
1382
+ return (y + z).view(shape)
1383
+
1384
+ def getBF8PrintStr(ele):
1385
+ v = int(ele.cpu().view(torch.uint8).item())
1386
+ ex = v >> 3 & 0xF
1387
+ r = v & 0x7
1388
+
1389
+ if ex == 15 and r == 7:
1390
+ print(f'BF8 Nan: {ex} {r} !!!', flush=True)
1391
+ elif ex == 0:
1392
+ print(f'BF8 subnormal: {ex} {r} !!!', flush=True)
1393
+
1394
+ if v & 0x80:
1395
+ vstr = f'-{ex} {r}'
1396
+ else:
1397
+ vstr = f'{ex} {r}'
1398
+ return vstr
1399
+
1400
+ class Block(nn.Module):
1401
+ """
1402
+ Transformer block combining attention and feed-forward layers.
1403
+
1404
+ Attributes:
1405
+ attn (nn.Module): Attention layer (MLA).
1406
+ ffn (nn.Module): Feed-forward network (MLP or MoE).
1407
+ attn_norm (nn.Module): Layer normalization for attention.
1408
+ ffn_norm (nn.Module): Layer normalization for feed-forward network.
1409
+ """
1410
+ def __init__(self, layer_id: int, args: ModelArgs, ckpt_path):
1411
+ """
1412
+ Initializes the Transformer block.
1413
+
1414
+ Args:
1415
+ layer_id (int): Layer index in the transformer.
1416
+ args (ModelArgs): Model arguments containing block parameters.
1417
+ """
1418
+ super().__init__()
1419
+ self.layer_id = layer_id
1420
+ self.ckpt_path = ckpt_path
1421
+ self.attn = MLA(layer_id, args)
1422
+ self.ffn = MLP_int(layer_id, args.dim, args.inter_dim) if layer_id < args.n_dense_layers else MoE(layer_id, args, ckpt_path)
1423
+ # print('args.dim: ' + str(args.dim))
1424
+ # args.dim = 7168
1425
+ self.attn_norm = RMSNorm_int(args.dim, torch.int32)
1426
+ self.ffn_norm = RMSNorm_int(args.dim, torch.int32)
1427
+ # self.ffn_norm = RMSNorm(args.dim)
1428
+
1429
+ def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor]) -> torch.Tensor:
1430
+ """
1431
+ Forward pass for the Transformer block.
1432
+
1433
+ Args:
1434
+ x (torch.Tensor): Input tensor.
1435
+ start_pos (int): Starting position in the sequence.
1436
+ freqs_cis (torch.Tensor): Precomputed complex exponential values for rotary embeddings.
1437
+ mask (Optional[torch.Tensor]): Mask tensor to exclude certain positions from attention.
1438
+
1439
+ Returns:
1440
+ torch.Tensor: Output tensor after block computation.
1441
+ """
1442
+
1443
+ x_abs = x.abs()
1444
+ x_abs_min = x_abs.min().item()
1445
+ x_abs_max = x_abs.max().item()
1446
+ print(f'x abs min: {x_abs_min}, max: {x_abs_max}', flush=True)
1447
+
1448
+ # self.attn_norm(x): 在进行attention之前,先将7168维的embeding 进行 归一化
1449
+ # attn_norm 的 scale 为 2^21, x 的 scale 为 2^31
1450
+ (atten_normed, rms) = self.attn_norm(x)
1451
+
1452
+ if snark:
1453
+ os.makedirs(f'zkdata/pos_{start_pos}/layer_{self.layer_id}', exist_ok=True)
1454
+ saveTensor(f'zkdata/pos_{start_pos}/layer_{self.layer_id}/attn_norm_x.bin', x.cpu())
1455
+ saveTensor(f'zkdata/pos_{start_pos}/layer_{self.layer_id}/attn_norm_weight.bin', self.attn_norm.weight.view(torch.uint32).cpu())
1456
+ saveTensor(f'zkdata/pos_{start_pos}/layer_{self.layer_id}/attn_norm_y.bin', atten_normed.cpu())
1457
+ saveTensor(f'zkdata/pos_{start_pos}/layer_{self.layer_id}/attn_norm_rms.bin', rms.cpu())
1458
+
1459
+ # attned 的 rescale 是 2^19, shape: [1, seqLen, 7168]
1460
+ attned = self.attn(atten_normed, start_pos, freqs_cis, mask)
1461
+
1462
+ # 调整 rescale,因为 x 的 rescale 是 2^31, attned 的 rescale 是 2^19,因此要乘以 2^12
1463
+ # x = x + attned * (2 ** 10)
1464
+ x = x + attned * (2 ** 12)
1465
+
1466
+ # ffn_normed 的 rescale 为 2^23
1467
+ (ffn_normed, rms) = self.ffn_norm(x)
1468
+
1469
+ ffned = self.ffn(start_pos, ffn_normed)
1470
+ # x = x + ffned * (2 ** 6)
1471
+ x = x + ffned * (2 ** 8)
1472
+
1473
+ # 返回的 x 的rescale 为 2^31
1474
+ return x
1475
+
1476
# The Transformer fixes its process rank at construction time and is built
# from the classic components: an embedding layer (self.embed), a stack of
# decoder blocks (self.layers — 61 for the 671B config), a final RMSNorm
# (self.norm), and an output head (self.head) projecting hidden states
# (dim 7168) onto the vocabulary distribution (129280 entries).
1482
+ class Transformer(nn.Module):
1483
+ """
1484
+ Transformer model with positional embeddings, multiple layers, and output projection.
1485
+
1486
+ Attributes:
1487
+ max_seq_len (int): Maximum sequence length for the transformer.
1488
+ embed (nn.Module): Embedding layer for input tokens.
1489
+ layers (torch.nn.ModuleList): List of transformer blocks.
1490
+ norm (nn.Module): Layer normalization applied after all blocks.
1491
+ head (nn.Module): Output projection layer mapping to vocabulary size.
1492
+ freqs_cis (torch.Tensor): Precomputed complex exponential values for rotary(旋转的) embeddings.
1493
+ """
1494
+ def __init__(self, args: ModelArgs):
1495
+ """
1496
+ Initializes the Transformer model.
1497
+
1498
+ Args:
1499
+ args (ModelArgs): Model arguments containing transformer parameters.
1500
+ """
1501
+ global world_size, rank
1502
+ world_size = dist.get_world_size() if dist.is_initialized() else 1
1503
+ rank = dist.get_rank() if dist.is_initialized() else 0
1504
+ Linear.dtype = torch.float8_e4m3fn if args.dtype == "fp8" else torch.bfloat16
1505
+ super().__init__()
1506
+ self.args = args
1507
+ self.max_seq_len = args.max_seq_len
1508
+ self.embed = ParallelEmbedding(args.vocab_size, args.dim)
1509
+ self.layers = torch.nn.ModuleList()
1510
+ for layer_id in range(args.n_layers):
1511
+ # self.layers.append(Block(layer_id, args))
1512
+ self.layers.append(nn.Module())
1513
+
1514
+ self.norm = RMSNorm_int(args.dim, torch.int64)
1515
+ # self.head = ColumnParallelLinear(-1, args.dim, args.vocab_size, dtype=torch.get_default_dtype())
1516
+ # 模型中的 head 的 rescale 为 2^43, 使用的过程中的rescale为 2^35, head 输入的 rescale为 2^15, 输出的 rescale为 2^21
1517
+ # self.head = ColumnParallelLinear_int(-1, args.dim, args.vocab_size, 1, (1 << 8), (1 << 29), torch.int64)
1518
+ self.head = ColumnParallelLinear_int(-1, args.dim, args.vocab_size, 1, (1 << 8), 29, torch.int64)
1519
+ # self.head = ColumnParallelLinear_int(-1, args.dim, args.vocab_size, 1, (1 << 8), (1 << 31), torch.int64)
1520
+ # self.head = ColumnParallelLinear_int(-1, args.dim, args.vocab_size, (1 << 5), (1 << 11), (1 << 21), torch.int64)
1521
+ # register_buffer()注册了名为 "freqs_cis" 的缓冲区,缓冲区的值由 precompute_freqs_cis(args) 提供,并且由于设置了 persistent=False,
1522
+ # 该缓冲区不会被保存到模型的状态字典中。缓冲区注册的张量是该Transformer类的位置编码。
1523
+ # register_buffer 用于注册一个非参数张量(tensor),这个张量虽然不是模型的可学习参数,但仍然是模型状态的一部分。
1524
+ # 与参数不同,缓冲区不会在反向传播中计算梯度,也不会被优化器更新,但它会随模型一起移动到相应的设备(如 GPU)上。
1525
+ # persistent=False表示这个参数表示该缓冲区不属于持久状态(persistent state)。也就是说,当你调用 model.state_dict() 保存模型时,
1526
+ # 这个缓冲区不会被包含进去。位置编码可以在模型加载后重新计算,不需要存储。
1527
+ self.register_buffer("freqs_cis", precompute_freqs_cis(args), persistent=False)
1528
+
1529
+ @torch.inference_mode()
1530
+ def prep_inference(self, tokens: torch.Tensor, start_pos: int = 0):
1531
+ # softmax_init()
1532
+ softmax_init_q19()
1533
+ softmax_init_q21()
1534
+ silu_init_q23()
1535
+
1536
+ seqlen = tokens.size(1)
1537
+
1538
+ # h 是经过embed之后的结果,embed将文本表达转化为词嵌入,h的形状为 (batch_size, seq_len, 7168)
1539
+ h = self.embed(tokens)
1540
+ # h = h.to(torch.bfloat16) * (1.0 / (1 << 44))
1541
+
1542
+ return (h, start_pos, seqlen)
1543
+
1544
+ @torch.inference_mode()
1545
+ def layer_inference(self, layer_id, h, start_pos, seqlen):
1546
+ freqs_cis = self.freqs_cis[start_pos:start_pos+seqlen]
1547
+ mask = None
1548
+
1549
+ # triu = triangle up
1550
+ # 返回上三角矩阵
1551
+ # 参数 k=0 代表主对角线,k 为正数则从主对角线开始向上数第 k 条,k 为负数则从主对角线开始向下数第 k 条
1552
+ if seqlen > 1:
1553
+ # mask = torch.full((seqlen, seqlen), float("-inf"), device="cuda").triu_(1)
1554
+ mask = torch.full((seqlen, seqlen), -(64 << 36), dtype=torch.int64, device="cuda").triu_(1)
1555
+
1556
+ h = self.layers[layer_id](h, start_pos, freqs_cis, mask)
1557
+
1558
+ h_abs = (h.to(torch.float32) * (2 ** -31)).to(torch.bfloat16).abs()
1559
+ h_abs_max = h_abs.max()
1560
+ h_abs[h_abs < (2 ** -125)] = h_abs_max
1561
+ h_abs_min = h_abs.min()
1562
+ h_abs_min_str = getBF16PrintStr(h_abs_min)
1563
+ h_abs_max_str = getBF16PrintStr(h_abs_max)
1564
+ print(f'h_abs min: {h_abs_min_str}, max: {h_abs_max_str}')
1565
+
1566
+ # 返回的 h 的rescale 为 2^31
1567
+ return h
1568
+
1569
+ @torch.inference_mode()
1570
+ def finish_inference(self, h):
1571
+ # norm的结果的scale = 2^15, h 的 scale = 2^15
1572
+ h = self.norm(h)[0][:, -1]
1573
+
1574
+ # logits 的rescale 为 2^21
1575
+ logits = self.head(h[None, :])
1576
+ if world_size > 1:
1577
+ all_logits = [torch.empty_like(logits) for _ in range(world_size)]
1578
+ dist.all_gather(all_logits, logits)
1579
+ logits = torch.cat(all_logits, dim=-1)
1580
+
1581
+ # logits 的 scale = 2^21
1582
+ return logits
1583
+
1584
+ # # 这里开始推理了,torch.inference_mode 这句话 关闭梯度计算 并 禁止 autograd 构建计算图,同时比 torch.no_grad() 还高效,专门为推理场景优化
1585
+ # @torch.inference_mode()
1586
+ # def forward(self, tokens: torch.Tensor, start_pos: int = 0):
1587
+ # """
1588
+ # Forward pass for the Transformer model.
1589
+
1590
+ # Args:
1591
+ # tokens (torch.Tensor): Input tensor of token IDs with shape (batch_size, seq_len).
1592
+ # start_pos (int, optional): Starting position in the sequence for rotary(旋转的) embeddings. Defaults to 0.
1593
+
1594
+ # Returns:
1595
+ # torch.Tensor: Logits tensor of shape (batch_size, vocab_size).
1596
+ # """
1597
+ # seqlen = tokens.size(1)
1598
+ # # h 是经过embed之后的结果,embed将文本表达转化为词嵌入,h的形状为 (batch_size, seq_len, 7168)
1599
+ # h = self.embed(tokens)
1600
+ # freqs_cis = self.freqs_cis[start_pos:start_pos+seqlen]
1601
+ # print('freqs_cis: ' + str(freqs_cis.tolist()))
1602
+
1603
+ # mask = None
1604
+
1605
+ # # triu = triangle up
1606
+ # # 返回上三角矩阵
1607
+ # # 参数 k=0 代表主对角线,k 为正数则从主对角线开始向上数第 k 条,k 为负数则从主对角线开始向下数第 k 条
1608
+ # if seqlen > 1:
1609
+ # mask = torch.full((seqlen, seqlen), float("-inf"), device=tokens.device).triu_(1)
1610
+
1611
+ # for layer in self.layers:
1612
+ # h = layer(h, start_pos, freqs_cis, mask)
1613
+
1614
+ # # 只取最后一个 token
1615
+ # h = self.norm(h)[:, -1]
1616
+ # logits = self.head(h)
1617
+ # if world_size > 1:
1618
+ # all_logits = [torch.empty_like(logits) for _ in range(world_size)]
1619
+ # dist.all_gather(all_logits, logits)
1620
+ # logits = torch.cat(all_logits, dim=-1)
1621
+ # return logits
1622
+
1623
+
1624
+ if __name__ == "__main__":
1625
+ torch.set_default_dtype(torch.bfloat16)
1626
+ torch.set_default_device("cuda")
1627
+ torch.manual_seed(0)
1628
+ args = ModelArgs()
1629
+ x = torch.randint(0, args.vocab_size, (2, 128))
1630
+ model = Transformer(0, args)
1631
+ print(model(x).size())
inference/requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ torch==2.4.1
2
+ triton==3.0.0
3
+ transformers==4.46.3
4
+ safetensors==0.4.5