init commit
Browse files- inference/convert2.py +1 -1
- inference/generate.py +4 -3
- inference/model.py +26 -25
- inference/runLLM.sh +1 -0
- zk/babel.config.cjs +3 -0
- zk/config.json +4 -0
- zk/package.json +46 -0
- zk/runZK.py +408 -0
- zk/src/.DS_Store +0 -0
- zk/src/index.ts +0 -0
- zk/tsconfig.json +26 -0
inference/convert2.py
CHANGED
|
@@ -32,7 +32,7 @@ mapping = {
|
|
| 32 |
}
|
| 33 |
|
| 34 |
EmbedsInOneFile = 256
|
| 35 |
-
EmbedsZKDir = "zkdata/embeds/"
|
| 36 |
|
| 37 |
wkv_b_1_rescales = [32, 34, 37, 36, 33, 32, 33, 33, 30, 32,
|
| 38 |
32, 30, 31, 30, 29, 30, 29, 30, 29, 29,
|
|
|
|
| 32 |
}
|
| 33 |
|
| 34 |
EmbedsInOneFile = 256
|
| 35 |
+
EmbedsZKDir = "../zkdata/embeds/"
|
| 36 |
|
| 37 |
wkv_b_1_rescales = [32, 34, 37, 36, 33, 32, 33, 33, 30, 32,
|
| 38 |
32, 30, 31, 30, 29, 30, 29, 30, 29, 29,
|
inference/generate.py
CHANGED
|
@@ -14,7 +14,8 @@ from model import Transformer, ModelArgs, Block
|
|
| 14 |
from concurrent.futures import ThreadPoolExecutor
|
| 15 |
from kernel import softmax_q21, softmax_q19
|
| 16 |
|
| 17 |
-
snark =
|
|
|
|
| 18 |
|
| 19 |
model = None
|
| 20 |
kv_caches = [ torch.zeros(1, 4096 * 4, 512, dtype=torch.int64) ] * 61
|
|
@@ -215,8 +216,8 @@ def generate(
|
|
| 215 |
print(str(cur_pos) + ' ---------- token list: ' + str(tokens[0][prev_pos:cur_pos].tolist()), flush=True)
|
| 216 |
|
| 217 |
if snark:
|
| 218 |
-
os.makedirs(f'
|
| 219 |
-
saveTensor(f'
|
| 220 |
|
| 221 |
# logits = model.forward(tokens[:, prev_pos:cur_pos], prev_pos)
|
| 222 |
|
|
|
|
| 14 |
from concurrent.futures import ThreadPoolExecutor
|
| 15 |
from kernel import softmax_q21, softmax_q19
|
| 16 |
|
| 17 |
+
snark = True
|
| 18 |
+
zkDataDir = '../zkdata'
|
| 19 |
|
| 20 |
model = None
|
| 21 |
kv_caches = [ torch.zeros(1, 4096 * 4, 512, dtype=torch.int64) ] * 61
|
|
|
|
| 216 |
print(str(cur_pos) + ' ---------- token list: ' + str(tokens[0][prev_pos:cur_pos].tolist()), flush=True)
|
| 217 |
|
| 218 |
if snark:
|
| 219 |
+
os.makedirs(f'{zkDataDir}/pos_{prev_pos}', exist_ok=True)
|
| 220 |
+
saveTensor(f'{zkDataDir}/pos_{prev_pos}/tokens.bin', tokens[0][prev_pos:cur_pos].cpu())
|
| 221 |
|
| 222 |
# logits = model.forward(tokens[:, prev_pos:cur_pos], prev_pos)
|
| 223 |
|
inference/model.py
CHANGED
|
@@ -21,7 +21,8 @@ block_size = 128
|
|
| 21 |
gemm_impl: Literal["bf16", "fp8"] = "bf16"
|
| 22 |
attn_impl: Literal["naive", "absorb"] = "absorb"
|
| 23 |
|
| 24 |
-
snark =
|
|
|
|
| 25 |
|
| 26 |
@dataclass
|
| 27 |
class ModelArgs:
|
|
@@ -772,7 +773,7 @@ class MLA(nn.Module):
|
|
| 772 |
# q_down = self.wq_a(x)
|
| 773 |
|
| 774 |
if snark:
|
| 775 |
-
dirStr = f'
|
| 776 |
os.makedirs(dirStr, exist_ok=True)
|
| 777 |
saveTensor(f'{dirStr}/wq_a_x.bin', x.cpu())
|
| 778 |
saveTensor(f'{dirStr}/wq_a_w.bin', self.wq_a.weight.view(torch.uint32).cpu())
|
|
@@ -784,7 +785,7 @@ class MLA(nn.Module):
|
|
| 784 |
(q_normed, rms) = self.q_norm(q_down)
|
| 785 |
|
| 786 |
if snark:
|
| 787 |
-
dirStr = f'
|
| 788 |
os.makedirs(dirStr, exist_ok=True)
|
| 789 |
saveTensor(f'{dirStr}/q_norm_x.bin', q_down.cpu())
|
| 790 |
saveTensor(f'{dirStr}/q_norm_weight.bin', self.q_norm.weight.view(torch.uint32).cpu())
|
|
@@ -809,13 +810,13 @@ class MLA(nn.Module):
|
|
| 809 |
# freqs_cis 的 rescale 为 2^42, 计算之后 q_pe 的 rescale 为 2^19
|
| 810 |
|
| 811 |
if snark:
|
| 812 |
-
saveTensor(f'
|
| 813 |
-
saveTensor(f'
|
| 814 |
|
| 815 |
q_pe = apply_rotary_emb(q_pe, freqs_cis)
|
| 816 |
|
| 817 |
if snark:
|
| 818 |
-
saveTensor(f'
|
| 819 |
|
| 820 |
# 获取key和value的联合表示kv(即公式41中的)和包含位置信息的key表示k_pe(即公式43中的):输入乘以向下投影矩阵wkv_a后,按照最后一个维度拆分,
|
| 821 |
# 前面kv_lora_rank维作为key和value的联合表示,后面qk_rope_head_dim维添加rope位置信息(调用apply_rotary_emb)后得到包含rope位置信息的key表示;
|
|
@@ -824,7 +825,7 @@ class MLA(nn.Module):
|
|
| 824 |
kv, kv_rem = self.wkv_a1(x)
|
| 825 |
|
| 826 |
if snark:
|
| 827 |
-
dirStr = f'
|
| 828 |
os.makedirs(dirStr, exist_ok=True)
|
| 829 |
saveTensor(f'{dirStr}/wkv_a1_x.bin', x.cpu())
|
| 830 |
saveTensor(f'{dirStr}/wkv_a1_w.bin', self.wkv_a1.weight.view(torch.uint32).cpu())
|
|
@@ -909,12 +910,12 @@ class MLA(nn.Module):
|
|
| 909 |
|
| 910 |
# # softmax_q19 会破坏 scores 的原始数据,先拷贝一份数据
|
| 911 |
if snark:
|
| 912 |
-
saveTensor(f'
|
| 913 |
|
| 914 |
softmax_q19(scores.contiguous(), scores_new)
|
| 915 |
|
| 916 |
if snark:
|
| 917 |
-
saveTensor(f'
|
| 918 |
|
| 919 |
if attn_impl == "naive":
|
| 920 |
x = torch.einsum("bsht,bthd->bshd", scores, self.v_cache[:bsz, :end_pos])
|
|
@@ -1033,12 +1034,12 @@ class MLP_int(nn.Module):
|
|
| 1033 |
# silu_q25(r1, s1)
|
| 1034 |
|
| 1035 |
if snark:
|
| 1036 |
-
saveTensor(f'
|
| 1037 |
|
| 1038 |
silu_q23(r1, s1)
|
| 1039 |
|
| 1040 |
if snark:
|
| 1041 |
-
saveTensor(f'
|
| 1042 |
|
| 1043 |
# r2 rescale: 2^23, shape: [1, seqLen, inter_dim]
|
| 1044 |
r2 = self.w3(x)
|
|
@@ -1128,13 +1129,13 @@ class Gate(nn.Module):
|
|
| 1128 |
C = torch.empty_like(scores, dtype=torch.int64, device='cuda')
|
| 1129 |
|
| 1130 |
if snark:
|
| 1131 |
-
saveTensor(f'
|
| 1132 |
-
saveTensor(f'
|
| 1133 |
|
| 1134 |
sigmoid_q23(scores, C)
|
| 1135 |
|
| 1136 |
if snark:
|
| 1137 |
-
saveTensor(f'
|
| 1138 |
|
| 1139 |
# 当前 scores shape: [seqLen, 256]
|
| 1140 |
scores = C.squeeze(0)
|
|
@@ -1147,8 +1148,8 @@ class Gate(nn.Module):
|
|
| 1147 |
scores = scores + self.bias
|
| 1148 |
|
| 1149 |
if snark:
|
| 1150 |
-
saveTensor(f'
|
| 1151 |
-
saveTensor(f'
|
| 1152 |
|
| 1153 |
# n_groups = 8
|
| 1154 |
if self.n_groups > 1:
|
|
@@ -1192,8 +1193,8 @@ class Gate(nn.Module):
|
|
| 1192 |
weights = original_scores.gather(1, indices)
|
| 1193 |
|
| 1194 |
if snark:
|
| 1195 |
-
saveTensor(f'
|
| 1196 |
-
saveTensor(f'
|
| 1197 |
|
| 1198 |
# print(f'weights shape: {weights.shape}')
|
| 1199 |
if self.score_func == "sigmoid":
|
|
@@ -1265,12 +1266,12 @@ class Expert_int(nn.Module):
|
|
| 1265 |
# silu_q25(r1, s1)
|
| 1266 |
|
| 1267 |
if snark:
|
| 1268 |
-
saveTensor(f'
|
| 1269 |
|
| 1270 |
silu_q23(r1, s1)
|
| 1271 |
|
| 1272 |
if snark:
|
| 1273 |
-
saveTensor(f'
|
| 1274 |
|
| 1275 |
# r2 rescale: 2^23
|
| 1276 |
r2 = self.w3(x)
|
|
@@ -1450,11 +1451,11 @@ class Block(nn.Module):
|
|
| 1450 |
(atten_normed, rms) = self.attn_norm(x)
|
| 1451 |
|
| 1452 |
if snark:
|
| 1453 |
-
os.makedirs(f'
|
| 1454 |
-
saveTensor(f'
|
| 1455 |
-
saveTensor(f'
|
| 1456 |
-
saveTensor(f'
|
| 1457 |
-
saveTensor(f'
|
| 1458 |
|
| 1459 |
# attned 的 rescale 是 2^19, shape: [1, seqLen, 7168]
|
| 1460 |
attned = self.attn(atten_normed, start_pos, freqs_cis, mask)
|
|
|
|
| 21 |
gemm_impl: Literal["bf16", "fp8"] = "bf16"
|
| 22 |
attn_impl: Literal["naive", "absorb"] = "absorb"
|
| 23 |
|
| 24 |
+
snark = True
|
| 25 |
+
zkDataDir = '../zkdata'
|
| 26 |
|
| 27 |
@dataclass
|
| 28 |
class ModelArgs:
|
|
|
|
| 773 |
# q_down = self.wq_a(x)
|
| 774 |
|
| 775 |
if snark:
|
| 776 |
+
dirStr = f'{zkDataDir}/pos_{start_pos}/layer_{self.layer_id}'
|
| 777 |
os.makedirs(dirStr, exist_ok=True)
|
| 778 |
saveTensor(f'{dirStr}/wq_a_x.bin', x.cpu())
|
| 779 |
saveTensor(f'{dirStr}/wq_a_w.bin', self.wq_a.weight.view(torch.uint32).cpu())
|
|
|
|
| 785 |
(q_normed, rms) = self.q_norm(q_down)
|
| 786 |
|
| 787 |
if snark:
|
| 788 |
+
dirStr = f'{zkDataDir}/pos_{start_pos}/layer_{self.layer_id}'
|
| 789 |
os.makedirs(dirStr, exist_ok=True)
|
| 790 |
saveTensor(f'{dirStr}/q_norm_x.bin', q_down.cpu())
|
| 791 |
saveTensor(f'{dirStr}/q_norm_weight.bin', self.q_norm.weight.view(torch.uint32).cpu())
|
|
|
|
| 810 |
# freqs_cis 的 rescale 为 2^42, 计算之后 q_pe 的 rescale 为 2^19
|
| 811 |
|
| 812 |
if snark:
|
| 813 |
+
saveTensor(f'{zkDataDir}/pos_{start_pos}/layer_{self.layer_id}/q_pe_x.bin', q_pe.cpu())
|
| 814 |
+
saveTensor(f'{zkDataDir}/freqs_cis.bin', freqs_cis.cpu())
|
| 815 |
|
| 816 |
q_pe = apply_rotary_emb(q_pe, freqs_cis)
|
| 817 |
|
| 818 |
if snark:
|
| 819 |
+
saveTensor(f'{zkDataDir}/pos_{start_pos}/layer_{self.layer_id}/q_pe_y.bin', self.q_norm.weight.view(torch.uint32).cpu())
|
| 820 |
|
| 821 |
# 获取key和value的联合表示kv(即公式41中的)和包含位置信息的key表示k_pe(即公式43中的):输入乘以向下投影矩阵wkv_a后,按照最后一个维度拆分,
|
| 822 |
# 前面kv_lora_rank维作为key和value的联合表示,后面qk_rope_head_dim维添加rope位置信息(调用apply_rotary_emb)后得到包含rope位置信息的key表示;
|
|
|
|
| 825 |
kv, kv_rem = self.wkv_a1(x)
|
| 826 |
|
| 827 |
if snark:
|
| 828 |
+
dirStr = f'{zkDataDir}/pos_{start_pos}/layer_{self.layer_id}'
|
| 829 |
os.makedirs(dirStr, exist_ok=True)
|
| 830 |
saveTensor(f'{dirStr}/wkv_a1_x.bin', x.cpu())
|
| 831 |
saveTensor(f'{dirStr}/wkv_a1_w.bin', self.wkv_a1.weight.view(torch.uint32).cpu())
|
|
|
|
| 910 |
|
| 911 |
# # softmax_q19 会破坏 scores 的原始数据,先拷贝一份数据
|
| 912 |
if snark:
|
| 913 |
+
saveTensor(f'{zkDataDir}/pos_{start_pos}/layer_{self.layer_id}/scores_softmax_x.bin', scores.contiguous().cpu())
|
| 914 |
|
| 915 |
softmax_q19(scores.contiguous(), scores_new)
|
| 916 |
|
| 917 |
if snark:
|
| 918 |
+
saveTensor(f'{zkDataDir}/pos_{start_pos}/layer_{self.layer_id}/scores_softmax_y.bin', scores_new.cpu())
|
| 919 |
|
| 920 |
if attn_impl == "naive":
|
| 921 |
x = torch.einsum("bsht,bthd->bshd", scores, self.v_cache[:bsz, :end_pos])
|
|
|
|
| 1034 |
# silu_q25(r1, s1)
|
| 1035 |
|
| 1036 |
if snark:
|
| 1037 |
+
saveTensor(f'{zkDataDir}/pos_{start_pos}/layer_{self.layer_id}/mlp_silu_x.bin', r1.contiguous().cpu())
|
| 1038 |
|
| 1039 |
silu_q23(r1, s1)
|
| 1040 |
|
| 1041 |
if snark:
|
| 1042 |
+
saveTensor(f'{zkDataDir}/pos_{start_pos}/layer_{self.layer_id}/mlp_silu_y.bin', s1.cpu())
|
| 1043 |
|
| 1044 |
# r2 rescale: 2^23, shape: [1, seqLen, inter_dim]
|
| 1045 |
r2 = self.w3(x)
|
|
|
|
| 1129 |
C = torch.empty_like(scores, dtype=torch.int64, device='cuda')
|
| 1130 |
|
| 1131 |
if snark:
|
| 1132 |
+
saveTensor(f'{zkDataDir}/pos_{start_pos}/layer_{self.layer_id}/sigmoid_gate_x.bin', scores.cpu())
|
| 1133 |
+
saveTensor(f'{zkDataDir}/pos_{start_pos}/layer_{self.layer_id}/sigmoid_gate_r.bin', scores_rem.cpu())
|
| 1134 |
|
| 1135 |
sigmoid_q23(scores, C)
|
| 1136 |
|
| 1137 |
if snark:
|
| 1138 |
+
saveTensor(f'{zkDataDir}/pos_{start_pos}/layer_{self.layer_id}/sigmoid_gate_y.bin', C.cpu())
|
| 1139 |
|
| 1140 |
# 当前 scores shape: [seqLen, 256]
|
| 1141 |
scores = C.squeeze(0)
|
|
|
|
| 1148 |
scores = scores + self.bias
|
| 1149 |
|
| 1150 |
if snark:
|
| 1151 |
+
saveTensor(f'{zkDataDir}/pos_{start_pos}/layer_{self.layer_id}/gate_original_scores.bin', original_scores.contiguous().cpu())
|
| 1152 |
+
saveTensor(f'{zkDataDir}/pos_{start_pos}/layer_{self.layer_id}/gate_bias.bin', self.bias.view(torch.uint32).cpu())
|
| 1153 |
|
| 1154 |
# n_groups = 8
|
| 1155 |
if self.n_groups > 1:
|
|
|
|
| 1193 |
weights = original_scores.gather(1, indices)
|
| 1194 |
|
| 1195 |
if snark:
|
| 1196 |
+
saveTensor(f'{zkDataDir}/pos_{start_pos}/layer_{self.layer_id}/gate_indices.bin', indices.contiguous().cpu())
|
| 1197 |
+
saveTensor(f'{zkDataDir}/pos_{start_pos}/layer_{self.layer_id}/gate_weights.bin', weights.contiguous().cpu())
|
| 1198 |
|
| 1199 |
# print(f'weights shape: {weights.shape}')
|
| 1200 |
if self.score_func == "sigmoid":
|
|
|
|
| 1266 |
# silu_q25(r1, s1)
|
| 1267 |
|
| 1268 |
if snark:
|
| 1269 |
+
saveTensor(f'{zkDataDir}/pos_{start_pos}/layer_{self.layer_id}/expert_{self.idx}_silu_x.bin', r1.contiguous().cpu())
|
| 1270 |
|
| 1271 |
silu_q23(r1, s1)
|
| 1272 |
|
| 1273 |
if snark:
|
| 1274 |
+
saveTensor(f'{zkDataDir}/pos_{start_pos}/layer_{self.layer_id}/expert_{self.idx}_silu_y.bin', s1.cpu())
|
| 1275 |
|
| 1276 |
# r2 rescale: 2^23
|
| 1277 |
r2 = self.w3(x)
|
|
|
|
| 1451 |
(atten_normed, rms) = self.attn_norm(x)
|
| 1452 |
|
| 1453 |
if snark:
|
| 1454 |
+
os.makedirs(f'{zkDataDir}/pos_{start_pos}/layer_{self.layer_id}', exist_ok=True)
|
| 1455 |
+
saveTensor(f'{zkDataDir}/pos_{start_pos}/layer_{self.layer_id}/attn_norm_x.bin', x.cpu())
|
| 1456 |
+
saveTensor(f'{zkDataDir}/pos_{start_pos}/layer_{self.layer_id}/attn_norm_weight.bin', self.attn_norm.weight.view(torch.uint32).cpu())
|
| 1457 |
+
saveTensor(f'{zkDataDir}/pos_{start_pos}/layer_{self.layer_id}/attn_norm_y.bin', atten_normed.cpu())
|
| 1458 |
+
saveTensor(f'{zkDataDir}/pos_{start_pos}/layer_{self.layer_id}/attn_norm_rms.bin', rms.cpu())
|
| 1459 |
|
| 1460 |
# attned 的 rescale 是 2^19, shape: [1, seqLen, 7168]
|
| 1461 |
attned = self.attn(atten_normed, start_pos, freqs_cis, mask)
|
inference/runLLM.sh
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
WORLD_SIZE=1 RANK=0 LOCAL_RANK=0 MASTER_ADDR=127.0.0.1 python3 generate.py --ckpt-path /data3/DeepSeek-V3-Demo1 --config configs/config_671B.json --interactive --temperature 1.0 --max-new-tokens 200 > logs/log_$(date +%Y%m%d_%H%M%S).txt 2>&1
|
zk/babel.config.cjs
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
module.exports = {
|
| 2 |
+
presets: [['@babel/preset-env', { targets: { node: 'current' } }]],
|
| 3 |
+
};
|
zk/config.json
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"version": 1,
|
| 3 |
+
"deployAliases": {}
|
| 4 |
+
}
|
zk/package.json
ADDED
|
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"name": "dp",
|
| 3 |
+
"version": "0.1.0",
|
| 4 |
+
"description": "",
|
| 5 |
+
"author": "",
|
| 6 |
+
"license": "Apache-2.0",
|
| 7 |
+
"keywords": [
|
| 8 |
+
"mina-zkapp",
|
| 9 |
+
"mina-zk-app",
|
| 10 |
+
"mina-dapp",
|
| 11 |
+
"zkapp"
|
| 12 |
+
],
|
| 13 |
+
"type": "module",
|
| 14 |
+
"main": "build/src/index.js",
|
| 15 |
+
"types": "build/src/index.d.ts",
|
| 16 |
+
"scripts": {
|
| 17 |
+
"build": "tsc",
|
| 18 |
+
"buildw": "tsc --watch",
|
| 19 |
+
"coverage": "node --experimental-vm-modules node_modules/jest/bin/jest.js --coverage",
|
| 20 |
+
"format": "prettier --write --ignore-unknown **/*",
|
| 21 |
+
"test": "npm run build && find build/src -name '*.test.js' -exec node --test {} \\;",
|
| 22 |
+
"testw": "npm run build && find build/src -name '*.test.js' -exec node --test --watch {} \\;",
|
| 23 |
+
"lint": "npx eslint src/* --fix",
|
| 24 |
+
"clear-cache": "npx rimraf cache/* !cache/README.md && npx rimraf cache.json && echo \"Cache cleared successfully!\"",
|
| 25 |
+
"start": "node build/src/run.js"
|
| 26 |
+
},
|
| 27 |
+
"devDependencies": {
|
| 28 |
+
"@babel/preset-env": "^7.16.4",
|
| 29 |
+
"@babel/preset-typescript": "^7.16.0",
|
| 30 |
+
"@types/node": "^22.14.1",
|
| 31 |
+
"@typescript-eslint/eslint-plugin": "^5.5.0",
|
| 32 |
+
"@typescript-eslint/parser": "^5.5.0",
|
| 33 |
+
"eslint": "^8.7.0",
|
| 34 |
+
"eslint-plugin-o1js": "^0.4.0",
|
| 35 |
+
"prettier": "^2.3.2",
|
| 36 |
+
"ts-jest": "^29.2.4",
|
| 37 |
+
"typescript": "^5.4.5",
|
| 38 |
+
"commander": "^14.0.2"
|
| 39 |
+
},
|
| 40 |
+
"peerDependencies": {
|
| 41 |
+
"o1js": "^2.*"
|
| 42 |
+
},
|
| 43 |
+
"engines": {
|
| 44 |
+
"node": ">=18.14.0"
|
| 45 |
+
}
|
| 46 |
+
}
|
zk/runZK.py
ADDED
|
@@ -0,0 +1,408 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import asyncio
|
| 2 |
+
import numpy as np
|
| 3 |
+
|
| 4 |
+
zkDataDir = '../zkdata'
|
| 5 |
+
|
| 6 |
+
async def main():
|
| 7 |
+
sem1 = asyncio.Semaphore(1)
|
| 8 |
+
sem7 = asyncio.Semaphore(7)
|
| 9 |
+
sem8 = asyncio.Semaphore(8)
|
| 10 |
+
sem32 = asyncio.Semaphore(32)
|
| 11 |
+
|
| 12 |
+
async def taskEmbed():
|
| 13 |
+
print(f'taskEmbed')
|
| 14 |
+
|
| 15 |
+
fEmbed = open("embed.log", "a", buffering=1)
|
| 16 |
+
fEmbedErr = open("embedErr.log", "w", buffering=1)
|
| 17 |
+
|
| 18 |
+
data = np.fromfile(f"{zkDataDir}/pos_0/tokens.bin", dtype=np.int64)
|
| 19 |
+
print('xs: ', data)
|
| 20 |
+
dataLen = len(data)
|
| 21 |
+
|
| 22 |
+
# 计算 所有 vocabulary embedding 的 hash
|
| 23 |
+
async def computeHash(tokenId):
|
| 24 |
+
async with sem32:
|
| 25 |
+
p = await asyncio.create_subprocess_exec("bash", "-lc", f'node build/src/index.js computeHash embed {tokenId}',
|
| 26 |
+
stdout=fEmbed, stderr=fEmbedErr)
|
| 27 |
+
rc = await p.wait()
|
| 28 |
+
return (tokenId, rc)
|
| 29 |
+
|
| 30 |
+
results = await asyncio.gather(*(computeHash(i) for i in range(0, 129280) ))
|
| 31 |
+
|
| 32 |
+
# 汇集所有的 vocabulary embedding 到 hashTable.json 中
|
| 33 |
+
p = await asyncio.create_subprocess_exec("bash", "-lc", f'node build/src/index.js precomputeHashes embed',
|
| 34 |
+
stdout=fEmbed, stderr=fEmbedErr)
|
| 35 |
+
rc = await p.wait()
|
| 36 |
+
|
| 37 |
+
# 计算 tokens 的 root hash
|
| 38 |
+
p = await asyncio.create_subprocess_exec("bash", "-lc", f'node build/src/index.js computeEmbedHash embed',
|
| 39 |
+
stdout=fEmbed, stderr=fEmbedErr)
|
| 40 |
+
rc = await p.wait()
|
| 41 |
+
|
| 42 |
+
async def taskEmbedBase(rowId):
|
| 43 |
+
async with sem7:
|
| 44 |
+
p = await asyncio.create_subprocess_exec("bash", "-lc", f'node build/src/index.js embedSectionBase embed {rowId}',
|
| 45 |
+
stdout=fEmbed, stderr=fEmbedErr)
|
| 46 |
+
rc = await p.wait()
|
| 47 |
+
return (rowId, rc)
|
| 48 |
+
|
| 49 |
+
results = await asyncio.gather(*(taskEmbedBase(i) for i in range(0, dataLen) ))
|
| 50 |
+
|
| 51 |
+
async def taskEmbedMerge(rowId):
|
| 52 |
+
async with sem7:
|
| 53 |
+
p = await asyncio.create_subprocess_exec("bash", "-lc", f'node build/src/index.js embedSectionMerge embed {rowId}',
|
| 54 |
+
stdout=fEmbed, stderr=fEmbedErr)
|
| 55 |
+
rc = await p.wait()
|
| 56 |
+
return (rowId, rc)
|
| 57 |
+
|
| 58 |
+
results = await asyncio.gather(*(taskEmbedMerge(i) for i in range(0, dataLen) ))
|
| 59 |
+
|
| 60 |
+
async def taskEmbedRowsMerge():
|
| 61 |
+
async with sem7:
|
| 62 |
+
p = await asyncio.create_subprocess_exec("bash", "-lc", f'node build/src/index.js embedRowsMerge embed',
|
| 63 |
+
stdout=fEmbed, stderr=fEmbedErr)
|
| 64 |
+
rc = await p.wait()
|
| 65 |
+
return rc
|
| 66 |
+
|
| 67 |
+
results = await asyncio.gather((taskEmbedRowsMerge() ))
|
| 68 |
+
|
| 69 |
+
fEmbed.close()
|
| 70 |
+
fEmbedErr.close()
|
| 71 |
+
|
| 72 |
+
async def taskAttnNorm(name, posId, layerId):
|
| 73 |
+
print(f'taskAttnNorm {name}')
|
| 74 |
+
|
| 75 |
+
fLog = open(f"{name}_Norm.log", "a", buffering=1)
|
| 76 |
+
fErr = open(f"{name}_NormErr.log", "w", buffering=1)
|
| 77 |
+
|
| 78 |
+
data = np.fromfile(f"{zkDataDir}/pos_0/tokens.bin", dtype=np.int64)
|
| 79 |
+
print('xs: ', data)
|
| 80 |
+
dataLen = len(data)
|
| 81 |
+
|
| 82 |
+
async def taskAttnNormBase(rowId, ind):
|
| 83 |
+
async with sem7:
|
| 84 |
+
p = await asyncio.create_subprocess_exec("bash", "-lc", f'node build/src/index.js normBase {name} {posId} {layerId} {rowId} {ind}',
|
| 85 |
+
stdout=fLog, stderr=fErr)
|
| 86 |
+
rc = await p.wait()
|
| 87 |
+
return (rowId, rc)
|
| 88 |
+
|
| 89 |
+
if name == 'attn_norm':
|
| 90 |
+
results = await asyncio.gather(*(taskAttnNormBase(i, j) for i in range(0, 24) for j in (0, 32)))
|
| 91 |
+
elif name == 'q_norm':
|
| 92 |
+
results = await asyncio.gather(*(taskAttnNormBase(i, 0) for i in range(0, 24)))
|
| 93 |
+
|
| 94 |
+
async def taskAttnNormMerge(rowId, startIdx):
|
| 95 |
+
async with sem8:
|
| 96 |
+
rc = 0
|
| 97 |
+
for j in range(startIdx, 0, -8):
|
| 98 |
+
p = await asyncio.create_subprocess_exec("bash", "-lc", f'node build/src/index.js normMerge {name} {posId} {layerId} {rowId} {j}',
|
| 99 |
+
stdout=fLog, stderr=fErr)
|
| 100 |
+
rc = await p.wait()
|
| 101 |
+
return (rowId, rc)
|
| 102 |
+
|
| 103 |
+
if name == 'attn_norm':
|
| 104 |
+
results = await asyncio.gather(*(taskAttnNormMerge(i, 62) for i in range(0, 24)))
|
| 105 |
+
elif name == 'q_norm':
|
| 106 |
+
results = await asyncio.gather(*(taskAttnNormMerge(i, 30) for i in range(0, 24)))
|
| 107 |
+
|
| 108 |
+
async def normWrapRow():
|
| 109 |
+
async with sem7:
|
| 110 |
+
p = await asyncio.create_subprocess_exec("bash", "-lc", f'node build/src/index.js normWrapRow {name} {posId} {layerId}',
|
| 111 |
+
stdout=fLog, stderr=fErr)
|
| 112 |
+
rc = await p.wait()
|
| 113 |
+
return rc
|
| 114 |
+
|
| 115 |
+
results = await asyncio.gather(normWrapRow())
|
| 116 |
+
|
| 117 |
+
async def normMergeRow():
|
| 118 |
+
async with sem7:
|
| 119 |
+
p = await asyncio.create_subprocess_exec("bash", "-lc", f'node build/src/index.js normMergeRow {name} {posId} {layerId}',
|
| 120 |
+
stdout=fLog, stderr=fErr)
|
| 121 |
+
rc = await p.wait()
|
| 122 |
+
return rc
|
| 123 |
+
|
| 124 |
+
results = await asyncio.gather(normMergeRow())
|
| 125 |
+
|
| 126 |
+
fLog.close()
|
| 127 |
+
fErr.close()
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
# gate 中 experts 选择逻辑
|
| 131 |
+
async def taskExpertSelector(name, posId, layerId):
|
| 132 |
+
print(f'taskGateExpertSelector')
|
| 133 |
+
|
| 134 |
+
fgate = open(f"{name}_expertSelector.log", "a", buffering=1)
|
| 135 |
+
fgateErr = open(f"{name}_expertSelectorErr.log", "w", buffering=1)
|
| 136 |
+
|
| 137 |
+
async def taskGroupBase(rowId):
|
| 138 |
+
async with sem8:
|
| 139 |
+
p = await asyncio.create_subprocess_exec("bash", "-lc", f'node build/src/index.js expertsGroupBase {name} {posId} {layerId} {rowId}',
|
| 140 |
+
stdout=fgate, stderr=fgateErr)
|
| 141 |
+
rc = await p.wait()
|
| 142 |
+
return (rowId, rc)
|
| 143 |
+
|
| 144 |
+
results = await asyncio.gather(*(taskGroupBase(i) for i in range(0, 24) ))
|
| 145 |
+
|
| 146 |
+
async def taskGroupMerge(rowId):
|
| 147 |
+
async with sem8:
|
| 148 |
+
p = await asyncio.create_subprocess_exec("bash", "-lc", f'node build/src/index.js expertsGroupMerge {name} {posId} {layerId} {rowId}',
|
| 149 |
+
stdout=fgate, stderr=fgateErr)
|
| 150 |
+
rc = await p.wait()
|
| 151 |
+
return (rowId, rc)
|
| 152 |
+
|
| 153 |
+
results = await asyncio.gather(*(taskGroupMerge(i) for i in range(0, 24) ))
|
| 154 |
+
|
| 155 |
+
async def taskSortedGroupBase(rowId):
|
| 156 |
+
async with sem8:
|
| 157 |
+
p = await asyncio.create_subprocess_exec("bash", "-lc", f'node build/src/index.js expertsSortedGroupBase {name} {posId} {layerId} {rowId}',
|
| 158 |
+
stdout=fgate, stderr=fgateErr)
|
| 159 |
+
rc = await p.wait()
|
| 160 |
+
return (rowId, rc)
|
| 161 |
+
|
| 162 |
+
results = await asyncio.gather(*(taskSortedGroupBase(i) for i in range(0, 24) ))
|
| 163 |
+
|
| 164 |
+
async def taskSortedGroupMerge(rowId):
|
| 165 |
+
async with sem8:
|
| 166 |
+
p = await asyncio.create_subprocess_exec("bash", "-lc", f'node build/src/index.js expertsSortedGroupMerge {name}s {posId} {layerId} {rowId}',
|
| 167 |
+
stdout=fgate, stderr=fgateErr)
|
| 168 |
+
rc = await p.wait()
|
| 169 |
+
return (rowId, rc)
|
| 170 |
+
|
| 171 |
+
results = await asyncio.gather(*(taskSortedGroupMerge(i) for i in range(0, 24) ))
|
| 172 |
+
|
| 173 |
+
async def taskSelectorBase(rowId):
|
| 174 |
+
async with sem8:
|
| 175 |
+
p = await asyncio.create_subprocess_exec("bash", "-lc", f'node build/src/index.js expertsSelectorBase {name} {posId} {layerId} {rowId}',
|
| 176 |
+
stdout=fgate, stderr=fgateErr)
|
| 177 |
+
rc = await p.wait()
|
| 178 |
+
return (rowId, rc)
|
| 179 |
+
|
| 180 |
+
results = await asyncio.gather(*(taskSelectorBase(i) for i in range(0, 24) ))
|
| 181 |
+
|
| 182 |
+
async def taskSelectorMerge():
|
| 183 |
+
async with sem8:
|
| 184 |
+
p = await asyncio.create_subprocess_exec("bash", "-lc", f'node build/src/index.js expertsSelectorMerge {name} {posId} {layerId}',
|
| 185 |
+
stdout=fgate, stderr=fgateErr)
|
| 186 |
+
rc = await p.wait()
|
| 187 |
+
return rc
|
| 188 |
+
|
| 189 |
+
results = await asyncio.gather((taskSelectorMerge() ))
|
| 190 |
+
|
| 191 |
+
fgate.close()
|
| 192 |
+
fgateErr.close()
|
| 193 |
+
|
| 194 |
+
|
| 195 |
+
async def taskRope_pe():
|
| 196 |
+
print(f'taskRope')
|
| 197 |
+
|
| 198 |
+
fLog = open("rope.log", "a", buffering=1)
|
| 199 |
+
fErr = open("ropeErr.log", "w", buffering=1)
|
| 200 |
+
|
| 201 |
+
data = np.fromfile(f"{zkDataDir}/pos_0/tokens.bin", dtype=np.int64)
|
| 202 |
+
print('xs: ', data)
|
| 203 |
+
dataLen = len(data)
|
| 204 |
+
|
| 205 |
+
async def ropeBase(name, posId, layerId, rowId, ind, f_out, f_err):
|
| 206 |
+
async with sem7:
|
| 207 |
+
p = await asyncio.create_subprocess_exec("bash", "-lc", f'node build/src/index.js ropeBase {name} {posId} {layerId} {rowId} {ind}',
|
| 208 |
+
stdout=f_out, stderr=f_err)
|
| 209 |
+
rc = await p.wait()
|
| 210 |
+
return rc
|
| 211 |
+
|
| 212 |
+
results = await asyncio.gather(*(ropeBase('q_pe', 0, 0, i, j, fLog, fErr) for i in range(0, 24) for j in (0, 32, 64, 96) ))
|
| 213 |
+
|
| 214 |
+
async def ropeMerge(name, posId, layerId, rowId, ind, f_out, f_err):
|
| 215 |
+
async with sem8:
|
| 216 |
+
p = await asyncio.create_subprocess_exec("bash", "-lc", f'node build/src/index.js ropeMerge {name} {posId} {layerId} {rowId} {ind}',
|
| 217 |
+
stdout=f_out, stderr=f_err)
|
| 218 |
+
rc = await p.wait()
|
| 219 |
+
return rc
|
| 220 |
+
|
| 221 |
+
for j in range(126, -1, -8):
|
| 222 |
+
results = await asyncio.gather(*(ropeMerge('q_pe', 0, 0, i, j, fLog, fErr) for i in range(0, 24) ))
|
| 223 |
+
|
| 224 |
+
async def wrapRopeRow(name, posId, layerId, f_out, f_err):
|
| 225 |
+
p = await asyncio.create_subprocess_exec("bash", "-lc", f'node build/src/index.js wrapRopeRow {name} {posId} {layerId}',
|
| 226 |
+
stdout=f_out, stderr=f_err)
|
| 227 |
+
rc = await p.wait()
|
| 228 |
+
return rc
|
| 229 |
+
|
| 230 |
+
results = await asyncio.gather(wrapRopeRow('q_pe', 0, 0, fLog, fErr))
|
| 231 |
+
|
| 232 |
+
async def mergeRopeRow(name, posId, layerId, f_out, f_err):
|
| 233 |
+
p = await asyncio.create_subprocess_exec("bash", "-lc", f'node build/src/index.js mergeRopeRow {name} {posId} {layerId}',
|
| 234 |
+
stdout=f_out, stderr=f_err)
|
| 235 |
+
rc = await p.wait()
|
| 236 |
+
return rc
|
| 237 |
+
|
| 238 |
+
results = await asyncio.gather(mergeRopeRow('q_pe', 0, 0, fLog, fErr))
|
| 239 |
+
|
| 240 |
+
fLog.close()
|
| 241 |
+
fErr.close()
|
| 242 |
+
|
| 243 |
+
async def taskSoftmax(name):
|
| 244 |
+
print(f'taskSoftmax {name}')
|
| 245 |
+
|
| 246 |
+
fLog = open(f"{name}_softmax.log", "a", buffering=1)
|
| 247 |
+
fErr = open(f"{name}_softmaxErr.log", "w", buffering=1)
|
| 248 |
+
|
| 249 |
+
async def softmaxHeadBase(posId, layerId, rowId, headId, headDim):
|
| 250 |
+
async with sem7:
|
| 251 |
+
p = await asyncio.create_subprocess_exec(
|
| 252 |
+
"bash", "-lc",
|
| 253 |
+
f'node build/src/index.js softmaxHeadBase {name} {posId} {layerId} {rowId} {headId} {headDim}',
|
| 254 |
+
stdout=fLog, stderr=fErr)
|
| 255 |
+
rc = await p.wait()
|
| 256 |
+
return rc
|
| 257 |
+
results = await asyncio.gather(*(softmaxHeadBase(0, 0, i, j, 24) for i in range(0, 24) for j in range(0, 128, 4)))
|
| 258 |
+
|
| 259 |
+
async def softmaxHeadMerge(posId, layerId, rowId, headDim):
|
| 260 |
+
async with sem8:
|
| 261 |
+
rc = 0
|
| 262 |
+
for ind in range(126, -1, -8):
|
| 263 |
+
p = await asyncio.create_subprocess_exec(
|
| 264 |
+
"bash", "-lc",
|
| 265 |
+
f'node build/src/index.js softmaxHeadMerge {name} {posId} {layerId} {rowId} {ind} {headDim}',
|
| 266 |
+
stdout=fLog, stderr=fErr)
|
| 267 |
+
rc = await p.wait()
|
| 268 |
+
return (rowId, rc)
|
| 269 |
+
results = await asyncio.gather(*(softmaxHeadMerge(0, 0, i, 24) for i in range(0, 24)))
|
| 270 |
+
|
| 271 |
+
async def softmaxWrapRow(posId, layerId):
|
| 272 |
+
p = await asyncio.create_subprocess_exec("bash", "-lc", f'node build/src/index.js softmaxWrapRow {name} {posId} {layerId}',
|
| 273 |
+
stdout=fLog, stderr=fErr)
|
| 274 |
+
rc = await p.wait()
|
| 275 |
+
return rc
|
| 276 |
+
results = await asyncio.gather(softmaxWrapRow(0, 0))
|
| 277 |
+
|
| 278 |
+
async def softmaxMergeRow(posId, layerId):
|
| 279 |
+
p = await asyncio.create_subprocess_exec("bash", "-lc", f'node build/src/index.js softmaxMergeRow {name} {posId} {layerId}',
|
| 280 |
+
stdout=fLog, stderr=fErr)
|
| 281 |
+
rc = await p.wait()
|
| 282 |
+
return rc
|
| 283 |
+
results = await asyncio.gather(softmaxMergeRow(0, 0))
|
| 284 |
+
|
| 285 |
+
fLog.close()
|
| 286 |
+
fErr.close()
|
| 287 |
+
|
| 288 |
+
async def taskSigmoid(name, posId=0, layerId=3):
    """Run the four-stage sigmoid proof pipeline for one (posId, layerId) pair.

    Each stage shells out to the node CLI (`build/src/index.js`), and the
    stages must run strictly in order:
      1. sigmoidSectionBase  — one proof per section (24 sections), fan-out
         bounded by the shared `sem8` semaphore;
      2. sigmoidSectionMerge — per-section merge step, same bound;
      3. sigmoidRowBase      — single row-level base step;
      4. sigmoidRowMerge     — single row-level merge step.

    Subprocess stdout is appended to `{name}_sigmoid.log`; stderr is written
    to `{name}_sigmoidErr.log` (truncated on each run).

    posId/layerId were previously hard-coded to (0, 3); the defaults keep
    that behavior for existing callers while allowing other positions/layers.
    """
    print(f'taskSigmoid {name}')

    fLog = open(f"{name}_sigmoid.log", "a", buffering=1)
    fErr = open(f"{name}_sigmoidErr.log", "w", buffering=1)
    try:
        async def _run(cmd):
            # Spawn one CLI invocation; the caller decides concurrency limits.
            proc = await asyncio.create_subprocess_exec("bash", "-lc", cmd,
                                                        stdout=fLog, stderr=fErr)
            return await proc.wait()

        async def sigmoidSectionBase(rowId):
            async with sem8:  # cap concurrent prover processes
                return await _run(f'node build/src/index.js sigmoidSectionBase {name} {posId} {layerId} {rowId}')

        async def sigmoidSectionMerge(rowId):
            async with sem8:
                return await _run(f'node build/src/index.js sigmoidSectionMerge {name} {posId} {layerId} {rowId}')

        # Stages 1 and 2: 24 sections each, as bounded fan-out.
        await asyncio.gather(*(sigmoidSectionBase(i) for i in range(24)))
        await asyncio.gather(*(sigmoidSectionMerge(i) for i in range(24)))

        # Stages 3 and 4: a single row-level process each, no semaphore
        # (matches the original, which only bounded the section stages).
        await _run(f'node build/src/index.js sigmoidRowBase {name} {posId} {layerId}')
        await _run(f'node build/src/index.js sigmoidRowMerge {name} {posId} {layerId}')
    finally:
        # The original leaked both handles if any stage raised; always close.
        fLog.close()
        fErr.close()
|
| 328 |
+
|
| 329 |
+
async def taskGemm(name, posId, layerId, InDim, OutDim, ShortDim):
    """Drive the staged GEMM proof pipeline for one (posId, layerId).

    The input row count comes from the per-position token file under
    `zkDataDir`; the matrix is processed in `InDim // ShortDim` segments.
    Six CLI stages run in strict order (each an `asyncio.gather` fan-out of
    `node build/src/index.js` subprocesses): gemmXBase, gemmXMergeRow,
    gemmWBase, gemmWMergeRow, gemmXWBase, gemmXWMerge. Concurrency is
    bounded by the shared `sem8` semaphore except the final merge, which
    uses `sem1`. Logs go to `{name}_gemm.log` / `{name}_gemmErr.log`.
    """
    print(f'taskGemm {name}')

    fLog = open(f"{name}_gemm.log", "a", buffering=1)
    fErr = open(f"{name}_gemmErr.log", "w", buffering=1)

    tokens = np.fromfile(f"{zkDataDir}/pos_{posId}/tokens.bin", dtype=np.int64)
    print('xs: ', tokens)
    rowCount = len(tokens)

    segmentCount = InDim // ShortDim
    # One starting index per group of 32 segments.
    startIndArr = [32 * k for k in range(segmentCount // 32)]
    print('startIndArr: ', startIndArr)

    async def _spawn(cmd):
        # Launch one CLI step and wait for it; output goes to the shared logs.
        proc = await asyncio.create_subprocess_exec("bash", "-lc", cmd,
                                                    stdout=fLog, stderr=fErr)
        return await proc.wait()

    async def gemmXBase(rowId, ind):
        async with sem8:
            return await _spawn(f'node build/src/index.js gemmXBase {name} {posId} {layerId} {rowId} {ind}')
    await asyncio.gather(*(gemmXBase(r, s) for r in range(rowCount) for s in startIndArr))

    async def gemmXMergeRow(ind):
        async with sem8:
            return await _spawn(f'node build/src/index.js gemmXMergeRow {name} {posId} {layerId} {ind}')
    await asyncio.gather(*(gemmXMergeRow(s) for s in range(segmentCount - 1, 2 * segmentCount - 1)))

    async def gemmWBase(rowId, ind):
        async with sem8:
            return await _spawn(f'node build/src/index.js gemmWBase {name} {posId} {layerId} {rowId} {ind}')
    await asyncio.gather(*(gemmWBase(r, s) for r in range(OutDim) for s in startIndArr))

    async def gemmWMergeRow(ind):
        # Row indices within one merge are processed sequentially under the
        # semaphore; only the last exit code is reported (as before).
        async with sem8:
            rc = 0
            for rowIndex in range(1, 512, 32):
                rc = await _spawn(f'node build/src/index.js gemmWMergeRow {name} {posId} {layerId} {ind} {rowIndex}')
            return rc
    await asyncio.gather(*(gemmWMergeRow(s) for s in range(segmentCount - 1, 2 * segmentCount - 1)))

    async def gemmXWBase(ind):
        async with sem8:
            return await _spawn(f'node build/src/index.js gemmXWBase {name} {posId} {layerId} {ind}')
    await asyncio.gather(*(gemmXWBase(s) for s in startIndArr))

    async def gemmXWMerge(ind):
        async with sem1:  # final merge runs one at a time
            return await _spawn(f'node build/src/index.js gemmXWMerge {name} {posId} {layerId} {ind}')
    await asyncio.gather(*(gemmXWMerge(s) for s in range(segmentCount - 2, -1, -8)))

    fLog.close()
    fErr.close()
|
| 395 |
+
|
| 396 |
+
# await taskExpertSelector_gate(0, 4)
|
| 397 |
+
await taskEmbed()
|
| 398 |
+
# await taskAttnNorm('attn_norm', 0, 0)
|
| 399 |
+
# await taskAttnNorm('q_norm')
|
| 400 |
+
# await taskRope_pe()
|
| 401 |
+
# await taskSoftmax('scores')
|
| 402 |
+
# await taskSigmoid('gate')
|
| 403 |
+
# await taskExpertSelector('gate', 0, 3)
|
| 404 |
+
# await taskGemm('wkv_a1', 0, 0, 7168, 512, 112)
|
| 405 |
+
|
| 406 |
+
print("all done.")
|
| 407 |
+
|
| 408 |
+
# Entry point: run the async driver `main` (defined above) to completion
# under a fresh event loop.
asyncio.run(main())
|
zk/src/.DS_Store
ADDED
|
Binary file (6.15 kB). View file
|
|
|
zk/src/index.ts
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
zk/tsconfig.json
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
  "compilerOptions": {
    "target": "es2021", // goal: ship *the most modern syntax* that is supported by *all* browsers that support our Wasm
    "module": "nodenext", // allow top-level await
    "lib": ["dom", "esnext"],
    "outDir": "./build",
    "rootDir": ".",
    "strict": true,
    "strictPropertyInitialization": false, // to enable generic constructors, e.g. on CircuitValue
    "skipLibCheck": true,
    "forceConsistentCasingInFileNames": true,
    "esModuleInterop": true,
    "moduleResolution": "nodenext", // comply with node + "type": "module"
    "experimentalDecorators": true, // needed for decorators used in o1js
    "emitDecoratorMetadata": true, // needed for decorators used in o1js
    "allowJs": true, // to use JSDoc in some places where TS would be too cumbersome
    "declaration": true,
    "sourceMap": true,
    "noFallthroughCasesInSwitch": true,
    "allowSyntheticDefaultImports": true,
    "useDefineForClassFields": false, // ensure correct behaviour of class fields with decorators
    "importHelpers": true, // bundle optimization to reduce size
    "baseUrl": "." // base directory for module resolution
  },
  "include": ["./src"]
}
|