asd52403 committed on
Commit
a96927f
·
1 Parent(s): da67f74

init commit

Browse files
inference/convert2.py CHANGED
@@ -32,7 +32,7 @@ mapping = {
32
  }
33
 
34
  EmbedsInOneFile = 256
35
- EmbedsZKDir = "zkdata/embeds/"
36
 
37
  wkv_b_1_rescales = [32, 34, 37, 36, 33, 32, 33, 33, 30, 32,
38
  32, 30, 31, 30, 29, 30, 29, 30, 29, 29,
 
32
  }
33
 
34
  EmbedsInOneFile = 256
35
+ EmbedsZKDir = "../zkdata/embeds/"
36
 
37
  wkv_b_1_rescales = [32, 34, 37, 36, 33, 32, 33, 33, 30, 32,
38
  32, 30, 31, 30, 29, 30, 29, 30, 29, 29,
inference/generate.py CHANGED
@@ -14,7 +14,8 @@ from model import Transformer, ModelArgs, Block
14
  from concurrent.futures import ThreadPoolExecutor
15
  from kernel import softmax_q21, softmax_q19
16
 
17
- snark = False
 
18
 
19
  model = None
20
  kv_caches = [ torch.zeros(1, 4096 * 4, 512, dtype=torch.int64) ] * 61
@@ -215,8 +216,8 @@ def generate(
215
  print(str(cur_pos) + ' ---------- token list: ' + str(tokens[0][prev_pos:cur_pos].tolist()), flush=True)
216
 
217
  if snark:
218
- os.makedirs(f'zkdata/pos_{prev_pos}', exist_ok=True)
219
- saveTensor(f'zkdata/pos_{prev_pos}/tokens.bin', tokens[0][prev_pos:cur_pos].cpu())
220
 
221
  # logits = model.forward(tokens[:, prev_pos:cur_pos], prev_pos)
222
 
 
14
  from concurrent.futures import ThreadPoolExecutor
15
  from kernel import softmax_q21, softmax_q19
16
 
17
+ snark = True
18
+ zkDataDir = '../zkdata'
19
 
20
  model = None
21
  kv_caches = [ torch.zeros(1, 4096 * 4, 512, dtype=torch.int64) ] * 61
 
216
  print(str(cur_pos) + ' ---------- token list: ' + str(tokens[0][prev_pos:cur_pos].tolist()), flush=True)
217
 
218
  if snark:
219
+ os.makedirs(f'{zkDataDir}/pos_{prev_pos}', exist_ok=True)
220
+ saveTensor(f'{zkDataDir}/pos_{prev_pos}/tokens.bin', tokens[0][prev_pos:cur_pos].cpu())
221
 
222
  # logits = model.forward(tokens[:, prev_pos:cur_pos], prev_pos)
223
 
inference/model.py CHANGED
@@ -21,7 +21,8 @@ block_size = 128
21
  gemm_impl: Literal["bf16", "fp8"] = "bf16"
22
  attn_impl: Literal["naive", "absorb"] = "absorb"
23
 
24
- snark = False
 
25
 
26
  @dataclass
27
  class ModelArgs:
@@ -772,7 +773,7 @@ class MLA(nn.Module):
772
  # q_down = self.wq_a(x)
773
 
774
  if snark:
775
- dirStr = f'zkdata/pos_{start_pos}/layer_{self.layer_id}'
776
  os.makedirs(dirStr, exist_ok=True)
777
  saveTensor(f'{dirStr}/wq_a_x.bin', x.cpu())
778
  saveTensor(f'{dirStr}/wq_a_w.bin', self.wq_a.weight.view(torch.uint32).cpu())
@@ -784,7 +785,7 @@ class MLA(nn.Module):
784
  (q_normed, rms) = self.q_norm(q_down)
785
 
786
  if snark:
787
- dirStr = f'zkdata/pos_{start_pos}/layer_{self.layer_id}'
788
  os.makedirs(dirStr, exist_ok=True)
789
  saveTensor(f'{dirStr}/q_norm_x.bin', q_down.cpu())
790
  saveTensor(f'{dirStr}/q_norm_weight.bin', self.q_norm.weight.view(torch.uint32).cpu())
@@ -809,13 +810,13 @@ class MLA(nn.Module):
809
  # freqs_cis 的 rescale 为 2^42, 计算之后 q_pe 的 rescale 为 2^19
810
 
811
  if snark:
812
- saveTensor(f'zkdata/pos_{start_pos}/layer_{self.layer_id}/q_pe_x.bin', q_pe.cpu())
813
- saveTensor(f'zkdata/freqs_cis.bin', freqs_cis.cpu())
814
 
815
  q_pe = apply_rotary_emb(q_pe, freqs_cis)
816
 
817
  if snark:
818
- saveTensor(f'zkdata/pos_{start_pos}/layer_{self.layer_id}/q_pe_y.bin', self.q_norm.weight.view(torch.uint32).cpu())
819
 
820
  # 获取key和value的联合表示kv(即公式41中的)和包含位置信息的key表示k_pe(即公式43中的):输入乘以向下投影矩阵wkv_a后,按照最后一个维度拆分,
821
  # 前面kv_lora_rank维作为key和value的联合表示,后面qk_rope_head_dim维添加rope位置信息(调用apply_rotary_emb)后得到包含rope位置信息的key表示;
@@ -824,7 +825,7 @@ class MLA(nn.Module):
824
  kv, kv_rem = self.wkv_a1(x)
825
 
826
  if snark:
827
- dirStr = f'zkdata/pos_{start_pos}/layer_{self.layer_id}'
828
  os.makedirs(dirStr, exist_ok=True)
829
  saveTensor(f'{dirStr}/wkv_a1_x.bin', x.cpu())
830
  saveTensor(f'{dirStr}/wkv_a1_w.bin', self.wkv_a1.weight.view(torch.uint32).cpu())
@@ -909,12 +910,12 @@ class MLA(nn.Module):
909
 
910
  # # softmax_q19 会破坏 scores 的原始数据,先拷贝一份数据
911
  if snark:
912
- saveTensor(f'zkdata/pos_{start_pos}/layer_{self.layer_id}/scores_softmax_x.bin', scores.contiguous().cpu())
913
 
914
  softmax_q19(scores.contiguous(), scores_new)
915
 
916
  if snark:
917
- saveTensor(f'zkdata/pos_{start_pos}/layer_{self.layer_id}/scores_softmax_y.bin', scores_new.cpu())
918
 
919
  if attn_impl == "naive":
920
  x = torch.einsum("bsht,bthd->bshd", scores, self.v_cache[:bsz, :end_pos])
@@ -1033,12 +1034,12 @@ class MLP_int(nn.Module):
1033
  # silu_q25(r1, s1)
1034
 
1035
  if snark:
1036
- saveTensor(f'zkdata/pos_{start_pos}/layer_{self.layer_id}/mlp_silu_x.bin', r1.contiguous().cpu())
1037
 
1038
  silu_q23(r1, s1)
1039
 
1040
  if snark:
1041
- saveTensor(f'zkdata/pos_{start_pos}/layer_{self.layer_id}/mlp_silu_y.bin', s1.cpu())
1042
 
1043
  # r2 rescale: 2^23, shape: [1, seqLen, inter_dim]
1044
  r2 = self.w3(x)
@@ -1128,13 +1129,13 @@ class Gate(nn.Module):
1128
  C = torch.empty_like(scores, dtype=torch.int64, device='cuda')
1129
 
1130
  if snark:
1131
- saveTensor(f'zkdata/pos_{start_pos}/layer_{self.layer_id}/sigmoid_gate_x.bin', scores.cpu())
1132
- saveTensor(f'zkdata/pos_{start_pos}/layer_{self.layer_id}/sigmoid_gate_r.bin', scores_rem.cpu())
1133
 
1134
  sigmoid_q23(scores, C)
1135
 
1136
  if snark:
1137
- saveTensor(f'zkdata/pos_{start_pos}/layer_{self.layer_id}/sigmoid_gate_y.bin', C.cpu())
1138
 
1139
  # 当前 scores shape: [seqLen, 256]
1140
  scores = C.squeeze(0)
@@ -1147,8 +1148,8 @@ class Gate(nn.Module):
1147
  scores = scores + self.bias
1148
 
1149
  if snark:
1150
- saveTensor(f'zkdata/pos_{start_pos}/layer_{self.layer_id}/gate_original_scores.bin', original_scores.contiguous().cpu())
1151
- saveTensor(f'zkdata/pos_{start_pos}/layer_{self.layer_id}/gate_bias.bin', self.bias.view(torch.uint32).cpu())
1152
 
1153
  # n_groups = 8
1154
  if self.n_groups > 1:
@@ -1192,8 +1193,8 @@ class Gate(nn.Module):
1192
  weights = original_scores.gather(1, indices)
1193
 
1194
  if snark:
1195
- saveTensor(f'zkdata/pos_{start_pos}/layer_{self.layer_id}/gate_indices.bin', indices.contiguous().cpu())
1196
- saveTensor(f'zkdata/pos_{start_pos}/layer_{self.layer_id}/gate_weights.bin', weights.contiguous().cpu())
1197
 
1198
  # print(f'weights shape: {weights.shape}')
1199
  if self.score_func == "sigmoid":
@@ -1265,12 +1266,12 @@ class Expert_int(nn.Module):
1265
  # silu_q25(r1, s1)
1266
 
1267
  if snark:
1268
- saveTensor(f'zkdata/pos_{start_pos}/layer_{self.layer_id}/expert_{self.idx}_silu_x.bin', r1.contiguous().cpu())
1269
 
1270
  silu_q23(r1, s1)
1271
 
1272
  if snark:
1273
- saveTensor(f'zkdata/pos_{start_pos}/layer_{self.layer_id}/expert_{self.idx}_silu_y.bin', s1.cpu())
1274
 
1275
  # r2 rescale: 2^23
1276
  r2 = self.w3(x)
@@ -1450,11 +1451,11 @@ class Block(nn.Module):
1450
  (atten_normed, rms) = self.attn_norm(x)
1451
 
1452
  if snark:
1453
- os.makedirs(f'zkdata/pos_{start_pos}/layer_{self.layer_id}', exist_ok=True)
1454
- saveTensor(f'zkdata/pos_{start_pos}/layer_{self.layer_id}/attn_norm_x.bin', x.cpu())
1455
- saveTensor(f'zkdata/pos_{start_pos}/layer_{self.layer_id}/attn_norm_weight.bin', self.attn_norm.weight.view(torch.uint32).cpu())
1456
- saveTensor(f'zkdata/pos_{start_pos}/layer_{self.layer_id}/attn_norm_y.bin', atten_normed.cpu())
1457
- saveTensor(f'zkdata/pos_{start_pos}/layer_{self.layer_id}/attn_norm_rms.bin', rms.cpu())
1458
 
1459
  # attned 的 rescale 是 2^19, shape: [1, seqLen, 7168]
1460
  attned = self.attn(atten_normed, start_pos, freqs_cis, mask)
 
21
  gemm_impl: Literal["bf16", "fp8"] = "bf16"
22
  attn_impl: Literal["naive", "absorb"] = "absorb"
23
 
24
+ snark = True
25
+ zkDataDir = '../zkdata'
26
 
27
  @dataclass
28
  class ModelArgs:
 
773
  # q_down = self.wq_a(x)
774
 
775
  if snark:
776
+ dirStr = f'{zkDataDir}/pos_{start_pos}/layer_{self.layer_id}'
777
  os.makedirs(dirStr, exist_ok=True)
778
  saveTensor(f'{dirStr}/wq_a_x.bin', x.cpu())
779
  saveTensor(f'{dirStr}/wq_a_w.bin', self.wq_a.weight.view(torch.uint32).cpu())
 
785
  (q_normed, rms) = self.q_norm(q_down)
786
 
787
  if snark:
788
+ dirStr = f'{zkDataDir}/pos_{start_pos}/layer_{self.layer_id}'
789
  os.makedirs(dirStr, exist_ok=True)
790
  saveTensor(f'{dirStr}/q_norm_x.bin', q_down.cpu())
791
  saveTensor(f'{dirStr}/q_norm_weight.bin', self.q_norm.weight.view(torch.uint32).cpu())
 
810
  # freqs_cis 的 rescale 为 2^42, 计算之后 q_pe 的 rescale 为 2^19
811
 
812
  if snark:
813
+ saveTensor(f'{zkDataDir}/pos_{start_pos}/layer_{self.layer_id}/q_pe_x.bin', q_pe.cpu())
814
+ saveTensor(f'{zkDataDir}/freqs_cis.bin', freqs_cis.cpu())
815
 
816
  q_pe = apply_rotary_emb(q_pe, freqs_cis)
817
 
818
  if snark:
819
+ saveTensor(f'{zkDataDir}/pos_{start_pos}/layer_{self.layer_id}/q_pe_y.bin', self.q_norm.weight.view(torch.uint32).cpu())
820
 
821
  # 获取key和value的联合表示kv(即公式41中的)和包含位置信息的key表示k_pe(即公式43中的):输入乘以向下投影矩阵wkv_a后,按照最后一个维度拆分,
822
  # 前面kv_lora_rank维作为key和value的联合表示,后面qk_rope_head_dim维添加rope位置信息(调用apply_rotary_emb)后得到包含rope位置信息的key表示;
 
825
  kv, kv_rem = self.wkv_a1(x)
826
 
827
  if snark:
828
+ dirStr = f'{zkDataDir}/pos_{start_pos}/layer_{self.layer_id}'
829
  os.makedirs(dirStr, exist_ok=True)
830
  saveTensor(f'{dirStr}/wkv_a1_x.bin', x.cpu())
831
  saveTensor(f'{dirStr}/wkv_a1_w.bin', self.wkv_a1.weight.view(torch.uint32).cpu())
 
910
 
911
  # # softmax_q19 会破坏 scores 的原始数据,先拷贝一份数据
912
  if snark:
913
+ saveTensor(f'{zkDataDir}/pos_{start_pos}/layer_{self.layer_id}/scores_softmax_x.bin', scores.contiguous().cpu())
914
 
915
  softmax_q19(scores.contiguous(), scores_new)
916
 
917
  if snark:
918
+ saveTensor(f'{zkDataDir}/pos_{start_pos}/layer_{self.layer_id}/scores_softmax_y.bin', scores_new.cpu())
919
 
920
  if attn_impl == "naive":
921
  x = torch.einsum("bsht,bthd->bshd", scores, self.v_cache[:bsz, :end_pos])
 
1034
  # silu_q25(r1, s1)
1035
 
1036
  if snark:
1037
+ saveTensor(f'{zkDataDir}/pos_{start_pos}/layer_{self.layer_id}/mlp_silu_x.bin', r1.contiguous().cpu())
1038
 
1039
  silu_q23(r1, s1)
1040
 
1041
  if snark:
1042
+ saveTensor(f'{zkDataDir}/pos_{start_pos}/layer_{self.layer_id}/mlp_silu_y.bin', s1.cpu())
1043
 
1044
  # r2 rescale: 2^23, shape: [1, seqLen, inter_dim]
1045
  r2 = self.w3(x)
 
1129
  C = torch.empty_like(scores, dtype=torch.int64, device='cuda')
1130
 
1131
  if snark:
1132
+ saveTensor(f'{zkDataDir}/pos_{start_pos}/layer_{self.layer_id}/sigmoid_gate_x.bin', scores.cpu())
1133
+ saveTensor(f'{zkDataDir}/pos_{start_pos}/layer_{self.layer_id}/sigmoid_gate_r.bin', scores_rem.cpu())
1134
 
1135
  sigmoid_q23(scores, C)
1136
 
1137
  if snark:
1138
+ saveTensor(f'{zkDataDir}/pos_{start_pos}/layer_{self.layer_id}/sigmoid_gate_y.bin', C.cpu())
1139
 
1140
  # 当前 scores shape: [seqLen, 256]
1141
  scores = C.squeeze(0)
 
1148
  scores = scores + self.bias
1149
 
1150
  if snark:
1151
+ saveTensor(f'{zkDataDir}/pos_{start_pos}/layer_{self.layer_id}/gate_original_scores.bin', original_scores.contiguous().cpu())
1152
+ saveTensor(f'{zkDataDir}/pos_{start_pos}/layer_{self.layer_id}/gate_bias.bin', self.bias.view(torch.uint32).cpu())
1153
 
1154
  # n_groups = 8
1155
  if self.n_groups > 1:
 
1193
  weights = original_scores.gather(1, indices)
1194
 
1195
  if snark:
1196
+ saveTensor(f'{zkDataDir}/pos_{start_pos}/layer_{self.layer_id}/gate_indices.bin', indices.contiguous().cpu())
1197
+ saveTensor(f'{zkDataDir}/pos_{start_pos}/layer_{self.layer_id}/gate_weights.bin', weights.contiguous().cpu())
1198
 
1199
  # print(f'weights shape: {weights.shape}')
1200
  if self.score_func == "sigmoid":
 
1266
  # silu_q25(r1, s1)
1267
 
1268
  if snark:
1269
+ saveTensor(f'{zkDataDir}/pos_{start_pos}/layer_{self.layer_id}/expert_{self.idx}_silu_x.bin', r1.contiguous().cpu())
1270
 
1271
  silu_q23(r1, s1)
1272
 
1273
  if snark:
1274
+ saveTensor(f'{zkDataDir}/pos_{start_pos}/layer_{self.layer_id}/expert_{self.idx}_silu_y.bin', s1.cpu())
1275
 
1276
  # r2 rescale: 2^23
1277
  r2 = self.w3(x)
 
1451
  (atten_normed, rms) = self.attn_norm(x)
1452
 
1453
  if snark:
1454
+ os.makedirs(f'{zkDataDir}/pos_{start_pos}/layer_{self.layer_id}', exist_ok=True)
1455
+ saveTensor(f'{zkDataDir}/pos_{start_pos}/layer_{self.layer_id}/attn_norm_x.bin', x.cpu())
1456
+ saveTensor(f'{zkDataDir}/pos_{start_pos}/layer_{self.layer_id}/attn_norm_weight.bin', self.attn_norm.weight.view(torch.uint32).cpu())
1457
+ saveTensor(f'{zkDataDir}/pos_{start_pos}/layer_{self.layer_id}/attn_norm_y.bin', atten_normed.cpu())
1458
+ saveTensor(f'{zkDataDir}/pos_{start_pos}/layer_{self.layer_id}/attn_norm_rms.bin', rms.cpu())
1459
 
1460
  # attned 的 rescale 是 2^19, shape: [1, seqLen, 7168]
1461
  attned = self.attn(atten_normed, start_pos, freqs_cis, mask)
inference/runLLM.sh ADDED
@@ -0,0 +1 @@
 
 
1
+ WORLD_SIZE=1 RANK=0 LOCAL_RANK=0 MASTER_ADDR=127.0.0.1 python3 generate.py --ckpt-path /data3/DeepSeek-V3-Demo1 --config configs/config_671B.json --interactive --temperature 1.0 --max-new-tokens 200 > logs/log_$(date +%Y%m%d_%H%M%S).txt 2>&1
zk/babel.config.cjs ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ module.exports = {
2
+ presets: [['@babel/preset-env', { targets: { node: 'current' } }]],
3
+ };
zk/config.json ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ {
2
+ "version": 1,
3
+ "deployAliases": {}
4
+ }
zk/package.json ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "name": "dp",
3
+ "version": "0.1.0",
4
+ "description": "",
5
+ "author": "",
6
+ "license": "Apache-2.0",
7
+ "keywords": [
8
+ "mina-zkapp",
9
+ "mina-zk-app",
10
+ "mina-dapp",
11
+ "zkapp"
12
+ ],
13
+ "type": "module",
14
+ "main": "build/src/index.js",
15
+ "types": "build/src/index.d.ts",
16
+ "scripts": {
17
+ "build": "tsc",
18
+ "buildw": "tsc --watch",
19
+ "coverage": "node --experimental-vm-modules node_modules/jest/bin/jest.js --coverage",
20
+ "format": "prettier --write --ignore-unknown **/*",
21
+ "test": "npm run build && find build/src -name '*.test.js' -exec node --test {} \\;",
22
+ "testw": "npm run build && find build/src -name '*.test.js' -exec node --test --watch {} \\;",
23
+ "lint": "npx eslint src/* --fix",
24
+ "clear-cache": "npx rimraf cache/* !cache/README.md && npx rimraf cache.json && echo \"Cache cleared successfully!\"",
25
+ "start": "node build/src/run.js"
26
+ },
27
+ "devDependencies": {
28
+ "@babel/preset-env": "^7.16.4",
29
+ "@babel/preset-typescript": "^7.16.0",
30
+ "@types/node": "^22.14.1",
31
+ "@typescript-eslint/eslint-plugin": "^5.5.0",
32
+ "@typescript-eslint/parser": "^5.5.0",
33
+ "eslint": "^8.7.0",
34
+ "eslint-plugin-o1js": "^0.4.0",
35
+ "prettier": "^2.3.2",
36
+ "ts-jest": "^29.2.4",
37
+ "typescript": "^5.4.5",
38
+ "commander": "^14.0.2"
39
+ },
40
+ "peerDependencies": {
41
+ "o1js": "^2.*"
42
+ },
43
+ "engines": {
44
+ "node": ">=18.14.0"
45
+ }
46
+ }
zk/runZK.py ADDED
@@ -0,0 +1,408 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import asyncio
2
+ import numpy as np
3
+
4
+ zkDataDir = '../zkdata'
5
+
6
+ async def main():
7
+ sem1 = asyncio.Semaphore(1)
8
+ sem7 = asyncio.Semaphore(7)
9
+ sem8 = asyncio.Semaphore(8)
10
+ sem32 = asyncio.Semaphore(32)
11
+
12
+ async def taskEmbed():
13
+ print(f'taskEmbed')
14
+
15
+ fEmbed = open("embed.log", "a", buffering=1)
16
+ fEmbedErr = open("embedErr.log", "w", buffering=1)
17
+
18
+ data = np.fromfile(f"{zkDataDir}/pos_0/tokens.bin", dtype=np.int64)
19
+ print('xs: ', data)
20
+ dataLen = len(data)
21
+
22
+ # 计算 所有 vocabulary embedding 的 hash
23
+ async def computeHash(tokenId):
24
+ async with sem32:
25
+ p = await asyncio.create_subprocess_exec("bash", "-lc", f'node build/src/index.js computeHash embed {tokenId}',
26
+ stdout=fEmbed, stderr=fEmbedErr)
27
+ rc = await p.wait()
28
+ return (tokenId, rc)
29
+
30
+ results = await asyncio.gather(*(computeHash(i) for i in range(0, 129280) ))
31
+
32
+ # 汇集所有的 vocabulary embedding 到 hashTable.json 中
33
+ p = await asyncio.create_subprocess_exec("bash", "-lc", f'node build/src/index.js precomputeHashes embed',
34
+ stdout=fEmbed, stderr=fEmbedErr)
35
+ rc = await p.wait()
36
+
37
+ # 计算 tokens 的 root hash
38
+ p = await asyncio.create_subprocess_exec("bash", "-lc", f'node build/src/index.js computeEmbedHash embed',
39
+ stdout=fEmbed, stderr=fEmbedErr)
40
+ rc = await p.wait()
41
+
42
+ async def taskEmbedBase(rowId):
43
+ async with sem7:
44
+ p = await asyncio.create_subprocess_exec("bash", "-lc", f'node build/src/index.js embedSectionBase embed {rowId}',
45
+ stdout=fEmbed, stderr=fEmbedErr)
46
+ rc = await p.wait()
47
+ return (rowId, rc)
48
+
49
+ results = await asyncio.gather(*(taskEmbedBase(i) for i in range(0, dataLen) ))
50
+
51
+ async def taskEmbedMerge(rowId):
52
+ async with sem7:
53
+ p = await asyncio.create_subprocess_exec("bash", "-lc", f'node build/src/index.js embedSectionMerge embed {rowId}',
54
+ stdout=fEmbed, stderr=fEmbedErr)
55
+ rc = await p.wait()
56
+ return (rowId, rc)
57
+
58
+ results = await asyncio.gather(*(taskEmbedMerge(i) for i in range(0, dataLen) ))
59
+
60
+ async def taskEmbedRowsMerge():
61
+ async with sem7:
62
+ p = await asyncio.create_subprocess_exec("bash", "-lc", f'node build/src/index.js embedRowsMerge embed',
63
+ stdout=fEmbed, stderr=fEmbedErr)
64
+ rc = await p.wait()
65
+ return rc
66
+
67
+ results = await asyncio.gather((taskEmbedRowsMerge() ))
68
+
69
+ fEmbed.close()
70
+ fEmbedErr.close()
71
+
72
+ async def taskAttnNorm(name, posId, layerId):
73
+ print(f'taskAttnNorm {name}')
74
+
75
+ fLog = open(f"{name}_Norm.log", "a", buffering=1)
76
+ fErr = open(f"{name}_NormErr.log", "w", buffering=1)
77
+
78
+ data = np.fromfile(f"{zkDataDir}/pos_0/tokens.bin", dtype=np.int64)
79
+ print('xs: ', data)
80
+ dataLen = len(data)
81
+
82
+ async def taskAttnNormBase(rowId, ind):
83
+ async with sem7:
84
+ p = await asyncio.create_subprocess_exec("bash", "-lc", f'node build/src/index.js normBase {name} {posId} {layerId} {rowId} {ind}',
85
+ stdout=fLog, stderr=fErr)
86
+ rc = await p.wait()
87
+ return (rowId, rc)
88
+
89
+ if name == 'attn_norm':
90
+ results = await asyncio.gather(*(taskAttnNormBase(i, j) for i in range(0, 24) for j in (0, 32)))
91
+ elif name == 'q_norm':
92
+ results = await asyncio.gather(*(taskAttnNormBase(i, 0) for i in range(0, 24)))
93
+
94
+ async def taskAttnNormMerge(rowId, startIdx):
95
+ async with sem8:
96
+ rc = 0
97
+ for j in range(startIdx, 0, -8):
98
+ p = await asyncio.create_subprocess_exec("bash", "-lc", f'node build/src/index.js normMerge {name} {posId} {layerId} {rowId} {j}',
99
+ stdout=fLog, stderr=fErr)
100
+ rc = await p.wait()
101
+ return (rowId, rc)
102
+
103
+ if name == 'attn_norm':
104
+ results = await asyncio.gather(*(taskAttnNormMerge(i, 62) for i in range(0, 24)))
105
+ elif name == 'q_norm':
106
+ results = await asyncio.gather(*(taskAttnNormMerge(i, 30) for i in range(0, 24)))
107
+
108
+ async def normWrapRow():
109
+ async with sem7:
110
+ p = await asyncio.create_subprocess_exec("bash", "-lc", f'node build/src/index.js normWrapRow {name} {posId} {layerId}',
111
+ stdout=fLog, stderr=fErr)
112
+ rc = await p.wait()
113
+ return rc
114
+
115
+ results = await asyncio.gather(normWrapRow())
116
+
117
+ async def normMergeRow():
118
+ async with sem7:
119
+ p = await asyncio.create_subprocess_exec("bash", "-lc", f'node build/src/index.js normMergeRow {name} {posId} {layerId}',
120
+ stdout=fLog, stderr=fErr)
121
+ rc = await p.wait()
122
+ return rc
123
+
124
+ results = await asyncio.gather(normMergeRow())
125
+
126
+ fLog.close()
127
+ fErr.close()
128
+
129
+
130
+ # gate 中 experts 选择逻辑
131
+ async def taskExpertSelector(name, posId, layerId):
132
+ print(f'taskGateExpertSelector')
133
+
134
+ fgate = open(f"{name}_expertSelector.log", "a", buffering=1)
135
+ fgateErr = open(f"{name}_expertSelectorErr.log", "w", buffering=1)
136
+
137
+ async def taskGroupBase(rowId):
138
+ async with sem8:
139
+ p = await asyncio.create_subprocess_exec("bash", "-lc", f'node build/src/index.js expertsGroupBase {name} {posId} {layerId} {rowId}',
140
+ stdout=fgate, stderr=fgateErr)
141
+ rc = await p.wait()
142
+ return (rowId, rc)
143
+
144
+ results = await asyncio.gather(*(taskGroupBase(i) for i in range(0, 24) ))
145
+
146
+ async def taskGroupMerge(rowId):
147
+ async with sem8:
148
+ p = await asyncio.create_subprocess_exec("bash", "-lc", f'node build/src/index.js expertsGroupMerge {name} {posId} {layerId} {rowId}',
149
+ stdout=fgate, stderr=fgateErr)
150
+ rc = await p.wait()
151
+ return (rowId, rc)
152
+
153
+ results = await asyncio.gather(*(taskGroupMerge(i) for i in range(0, 24) ))
154
+
155
+ async def taskSortedGroupBase(rowId):
156
+ async with sem8:
157
+ p = await asyncio.create_subprocess_exec("bash", "-lc", f'node build/src/index.js expertsSortedGroupBase {name} {posId} {layerId} {rowId}',
158
+ stdout=fgate, stderr=fgateErr)
159
+ rc = await p.wait()
160
+ return (rowId, rc)
161
+
162
+ results = await asyncio.gather(*(taskSortedGroupBase(i) for i in range(0, 24) ))
163
+
164
+ async def taskSortedGroupMerge(rowId):
165
+ async with sem8:
166
+ p = await asyncio.create_subprocess_exec("bash", "-lc", f'node build/src/index.js expertsSortedGroupMerge {name}s {posId} {layerId} {rowId}',
167
+ stdout=fgate, stderr=fgateErr)
168
+ rc = await p.wait()
169
+ return (rowId, rc)
170
+
171
+ results = await asyncio.gather(*(taskSortedGroupMerge(i) for i in range(0, 24) ))
172
+
173
+ async def taskSelectorBase(rowId):
174
+ async with sem8:
175
+ p = await asyncio.create_subprocess_exec("bash", "-lc", f'node build/src/index.js expertsSelectorBase {name} {posId} {layerId} {rowId}',
176
+ stdout=fgate, stderr=fgateErr)
177
+ rc = await p.wait()
178
+ return (rowId, rc)
179
+
180
+ results = await asyncio.gather(*(taskSelectorBase(i) for i in range(0, 24) ))
181
+
182
+ async def taskSelectorMerge():
183
+ async with sem8:
184
+ p = await asyncio.create_subprocess_exec("bash", "-lc", f'node build/src/index.js expertsSelectorMerge {name} {posId} {layerId}',
185
+ stdout=fgate, stderr=fgateErr)
186
+ rc = await p.wait()
187
+ return rc
188
+
189
+ results = await asyncio.gather((taskSelectorMerge() ))
190
+
191
+ fgate.close()
192
+ fgateErr.close()
193
+
194
+
195
+ async def taskRope_pe():
196
+ print(f'taskRope')
197
+
198
+ fLog = open("rope.log", "a", buffering=1)
199
+ fErr = open("ropeErr.log", "w", buffering=1)
200
+
201
+ data = np.fromfile(f"{zkDataDir}/pos_0/tokens.bin", dtype=np.int64)
202
+ print('xs: ', data)
203
+ dataLen = len(data)
204
+
205
+ async def ropeBase(name, posId, layerId, rowId, ind, f_out, f_err):
206
+ async with sem7:
207
+ p = await asyncio.create_subprocess_exec("bash", "-lc", f'node build/src/index.js ropeBase {name} {posId} {layerId} {rowId} {ind}',
208
+ stdout=f_out, stderr=f_err)
209
+ rc = await p.wait()
210
+ return rc
211
+
212
+ results = await asyncio.gather(*(ropeBase('q_pe', 0, 0, i, j, fLog, fErr) for i in range(0, 24) for j in (0, 32, 64, 96) ))
213
+
214
+ async def ropeMerge(name, posId, layerId, rowId, ind, f_out, f_err):
215
+ async with sem8:
216
+ p = await asyncio.create_subprocess_exec("bash", "-lc", f'node build/src/index.js ropeMerge {name} {posId} {layerId} {rowId} {ind}',
217
+ stdout=f_out, stderr=f_err)
218
+ rc = await p.wait()
219
+ return rc
220
+
221
+ for j in range(126, -1, -8):
222
+ results = await asyncio.gather(*(ropeMerge('q_pe', 0, 0, i, j, fLog, fErr) for i in range(0, 24) ))
223
+
224
+ async def wrapRopeRow(name, posId, layerId, f_out, f_err):
225
+ p = await asyncio.create_subprocess_exec("bash", "-lc", f'node build/src/index.js wrapRopeRow {name} {posId} {layerId}',
226
+ stdout=f_out, stderr=f_err)
227
+ rc = await p.wait()
228
+ return rc
229
+
230
+ results = await asyncio.gather(wrapRopeRow('q_pe', 0, 0, fLog, fErr))
231
+
232
+ async def mergeRopeRow(name, posId, layerId, f_out, f_err):
233
+ p = await asyncio.create_subprocess_exec("bash", "-lc", f'node build/src/index.js mergeRopeRow {name} {posId} {layerId}',
234
+ stdout=f_out, stderr=f_err)
235
+ rc = await p.wait()
236
+ return rc
237
+
238
+ results = await asyncio.gather(mergeRopeRow('q_pe', 0, 0, fLog, fErr))
239
+
240
+ fLog.close()
241
+ fErr.close()
242
+
243
+ async def taskSoftmax(name):
244
+ print(f'taskSoftmax {name}')
245
+
246
+ fLog = open(f"{name}_softmax.log", "a", buffering=1)
247
+ fErr = open(f"{name}_softmaxErr.log", "w", buffering=1)
248
+
249
+ async def softmaxHeadBase(posId, layerId, rowId, headId, headDim):
250
+ async with sem7:
251
+ p = await asyncio.create_subprocess_exec(
252
+ "bash", "-lc",
253
+ f'node build/src/index.js softmaxHeadBase {name} {posId} {layerId} {rowId} {headId} {headDim}',
254
+ stdout=fLog, stderr=fErr)
255
+ rc = await p.wait()
256
+ return rc
257
+ results = await asyncio.gather(*(softmaxHeadBase(0, 0, i, j, 24) for i in range(0, 24) for j in range(0, 128, 4)))
258
+
259
+ async def softmaxHeadMerge(posId, layerId, rowId, headDim):
260
+ async with sem8:
261
+ rc = 0
262
+ for ind in range(126, -1, -8):
263
+ p = await asyncio.create_subprocess_exec(
264
+ "bash", "-lc",
265
+ f'node build/src/index.js softmaxHeadMerge {name} {posId} {layerId} {rowId} {ind} {headDim}',
266
+ stdout=fLog, stderr=fErr)
267
+ rc = await p.wait()
268
+ return (rowId, rc)
269
+ results = await asyncio.gather(*(softmaxHeadMerge(0, 0, i, 24) for i in range(0, 24)))
270
+
271
+ async def softmaxWrapRow(posId, layerId):
272
+ p = await asyncio.create_subprocess_exec("bash", "-lc", f'node build/src/index.js softmaxWrapRow {name} {posId} {layerId}',
273
+ stdout=fLog, stderr=fErr)
274
+ rc = await p.wait()
275
+ return rc
276
+ results = await asyncio.gather(softmaxWrapRow(0, 0))
277
+
278
+ async def softmaxMergeRow(posId, layerId):
279
+ p = await asyncio.create_subprocess_exec("bash", "-lc", f'node build/src/index.js softmaxMergeRow {name} {posId} {layerId}',
280
+ stdout=fLog, stderr=fErr)
281
+ rc = await p.wait()
282
+ return rc
283
+ results = await asyncio.gather(softmaxMergeRow(0, 0))
284
+
285
+ fLog.close()
286
+ fErr.close()
287
+
288
+ async def taskSigmoid(name):
289
+ print(f'taskSigmoid {name}')
290
+
291
+ fLog = open(f"{name}_sigmoid.log", "a", buffering=1)
292
+ fErr = open(f"{name}_sigmoidErr.log", "w", buffering=1)
293
+
294
+ async def sigmoidSectionBase(posId, layerId, rowId):
295
+ async with sem8:
296
+ p = await asyncio.create_subprocess_exec("bash", "-lc", f'node build/src/index.js sigmoidSectionBase {name} {posId} {layerId} {rowId}',
297
+ stdout=fLog, stderr=fErr)
298
+ rc = await p.wait()
299
+ return rc
300
+
301
+ results = await asyncio.gather(*(sigmoidSectionBase(0, 3, i) for i in range(0, 24) ))
302
+
303
+ async def sigmoidSectionMerge(posId, layerId, rowId):
304
+ async with sem8:
305
+ p = await asyncio.create_subprocess_exec("bash", "-lc", f'node build/src/index.js sigmoidSectionMerge {name} {posId} {layerId} {rowId}',
306
+ stdout=fLog, stderr=fErr)
307
+ rc = await p.wait()
308
+ return rc
309
+
310
+ results = await asyncio.gather(*(sigmoidSectionMerge(0, 3, i) for i in range(0, 24) ))
311
+
312
+ async def sigmoidRowBase(posId, layerId):
313
+ p = await asyncio.create_subprocess_exec("bash", "-lc", f'node build/src/index.js sigmoidRowBase {name} {posId} {layerId}',
314
+ stdout=fLog, stderr=fErr)
315
+ rc = await p.wait()
316
+ return rc
317
+ results = await asyncio.gather(sigmoidRowBase(0, 3))
318
+
319
+ async def sigmoidRowMerge(posId, layerId):
320
+ p = await asyncio.create_subprocess_exec("bash", "-lc", f'node build/src/index.js sigmoidRowMerge {name} {posId} {layerId}',
321
+ stdout=fLog, stderr=fErr)
322
+ rc = await p.wait()
323
+ return rc
324
+ results = await asyncio.gather(sigmoidRowMerge(0, 3))
325
+
326
+ fLog.close()
327
+ fErr.close()
328
+
329
+ async def taskGemm(name, posId, layerId, InDim, OutDim, ShortDim):
330
+ print(f'taskGemm {name}')
331
+
332
+ fLog = open(f"{name}_gemm.log", "a", buffering=1)
333
+ fErr = open(f"{name}_gemmErr.log", "w", buffering=1)
334
+
335
+ data = np.fromfile(f"{zkDataDir}/pos_{posId}/tokens.bin", dtype=np.int64)
336
+ print('xs: ', data)
337
+ rowCount = len(data)
338
+
339
+ segmentCount = InDim // ShortDim
340
+ startIndArr = [i * 32 for i in range(0, segmentCount // 32)]
341
+ print('startIndArr: ', startIndArr)
342
+
343
+ async def gemmXBase(rowId, ind):
344
+ async with sem8:
345
+ p = await asyncio.create_subprocess_exec("bash", "-lc", f'node build/src/index.js gemmXBase {name} {posId} {layerId} {rowId} {ind}',
346
+ stdout=fLog, stderr=fErr)
347
+ rc = await p.wait()
348
+ return rc
349
+ results = await asyncio.gather(*(gemmXBase(i, j) for i in range(0, rowCount) for j in startIndArr))
350
+
351
+ async def gemmXMergeRow(ind):
352
+ async with sem8:
353
+ p = await asyncio.create_subprocess_exec("bash", "-lc", f'node build/src/index.js gemmXMergeRow {name} {posId} {layerId} {ind}',
354
+ stdout=fLog, stderr=fErr)
355
+ rc = await p.wait()
356
+ return rc
357
+ results = await asyncio.gather(*(gemmXMergeRow(j) for j in range(segmentCount - 1, 2 * segmentCount - 1)))
358
+
359
+ async def gemmWBase(rowId, ind):
360
+ async with sem8:
361
+ p = await asyncio.create_subprocess_exec("bash", "-lc", f'node build/src/index.js gemmWBase {name} {posId} {layerId} {rowId} {ind}',
362
+ stdout=fLog, stderr=fErr)
363
+ rc = await p.wait()
364
+ return rc
365
+ results = await asyncio.gather(*(gemmWBase(i, j) for i in range(0, OutDim) for j in startIndArr ))
366
+
367
+ async def gemmWMergeRow(ind):
368
+ async with sem8:
369
+ rc = 0
370
+ for rowIndex in range(1, 512, 32):
371
+ p = await asyncio.create_subprocess_exec("bash", "-lc", f'node build/src/index.js gemmWMergeRow {name} {posId} {layerId} {ind} {rowIndex}',
372
+ stdout=fLog, stderr=fErr)
373
+ rc = await p.wait()
374
+ return rc
375
+ results = await asyncio.gather(*(gemmWMergeRow(j) for j in range(segmentCount - 1, 2 * segmentCount - 1) ))
376
+
377
+ async def gemmXWBase(ind):
378
+ async with sem8:
379
+ p = await asyncio.create_subprocess_exec("bash", "-lc", f'node build/src/index.js gemmXWBase {name} {posId} {layerId} {ind}',
380
+ stdout=fLog, stderr=fErr)
381
+ rc = await p.wait()
382
+ return rc
383
+ results = await asyncio.gather(*(gemmXWBase(i) for i in startIndArr))
384
+
385
+ async def gemmXWMerge(ind):
386
+ async with sem1:
387
+ p = await asyncio.create_subprocess_exec("bash", "-lc", f'node build/src/index.js gemmXWMerge {name} {posId} {layerId} {ind}',
388
+ stdout=fLog, stderr=fErr)
389
+ rc = await p.wait()
390
+ return rc
391
+ results = await asyncio.gather(*(gemmXWMerge(i) for i in range(segmentCount - 2, -1, -8)))
392
+
393
+ fLog.close()
394
+ fErr.close()
395
+
396
+ # await taskExpertSelector_gate(0, 4)
397
+ await taskEmbed()
398
+ # await taskAttnNorm('attn_norm', 0, 0)
399
+ # await taskAttnNorm('q_norm')
400
+ # await taskRope_pe()
401
+ # await taskSoftmax('scores')
402
+ # await taskSigmoid('gate')
403
+ # await taskExpertSelector('gate', 0, 3)
404
+ # await taskGemm('wkv_a1', 0, 0, 7168, 512, 112)
405
+
406
+ print("all done.")
407
+
408
+ asyncio.run(main())
zk/src/.DS_Store ADDED
Binary file (6.15 kB). View file
 
zk/src/index.ts ADDED
The diff for this file is too large to render. See raw diff
 
zk/tsconfig.json ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "compilerOptions": {
3
+ "target": "es2021", // goal: ship *the most modern syntax* that is supported by *all* browsers that support our Wasm
4
+ "module": "nodenext", // allow top-level await
5
+ "lib": ["dom", "esnext"],
6
+ "outDir": "./build",
7
+ "rootDir": ".",
8
+ "strict": true,
9
+ "strictPropertyInitialization": false, // to enable generic constructors, e.g. on CircuitValue
10
+ "skipLibCheck": true,
11
+ "forceConsistentCasingInFileNames": true,
12
+ "esModuleInterop": true,
13
+ "moduleResolution": "nodenext", // comply with node + "type": "module"
14
+ "experimentalDecorators": true, // needed for decorators used in o1js
15
+ "emitDecoratorMetadata": true, // needed for decorators used in o1js
16
+ "allowJs": true, // to use JSDoc in some places where TS would be too cumbersome
17
+ "declaration": true,
18
+ "sourceMap": true,
19
+ "noFallthroughCasesInSwitch": true,
20
+ "allowSyntheticDefaultImports": true,
21
+ "useDefineForClassFields": false, // ensure correct behaviour of class fields with decorators
22
+ "importHelpers": true, // bundle optimization to reduce size
23
+ "baseUrl": "." // base directory for module resolution
24
+ },
25
+ "include": ["./src"]
26
+ }