Spaces:
Runtime error
Runtime error
Commit
·
729a063
1
Parent(s):
3a743e7
Update app.py
Browse files
app.py
CHANGED
|
@@ -76,13 +76,13 @@ st.caption("Multi-Head Attention")
|
|
| 76 |
mha_flop = 2*bs*h*n*(d/h)
|
| 77 |
mha_bytes = 2*bs*h*n + 2*bs*h*n*(d/h) + 2*bs*h*(d/h)
|
| 78 |
c1, c2 = st.columns([2, 3])
|
| 79 |
-
|
| 80 |
|
| 81 |
st.caption("Multi-Query Attention")
|
| 82 |
mqa_flop = 2*bs*h*n*(d/h)
|
| 83 |
mqa_bytes = 2*bs*n*(d/h) + 2*bs*n*(d/h) + 2*bs*h*(d/h)
|
| 84 |
c1, c2 = st.columns([2, 3])
|
| 85 |
-
|
| 86 |
|
| 87 |
st.subheader('Output projection')
|
| 88 |
out_flop = 2*bs*1*d*d
|
|
@@ -91,15 +91,17 @@ c1, c2 = st.columns([2, 3])
|
|
| 91 |
out_time = print_kernel_execution(c1, c2, out_flop, out_bytes)
|
| 92 |
|
| 93 |
st.subheader('Element-wise ops')
|
| 94 |
-
st.write("We also need to take into the softmax layer and
|
| 95 |
|
| 96 |
st.caption("Softmax")
|
| 97 |
softmax_bytes = 2*bs*h*n + 2*bs*h*n
|
| 98 |
c1, c2 = st.columns([2, 3])
|
| 99 |
softmax_time = print_kernel_execution(c1, c2, 0, softmax_bytes)
|
| 100 |
|
| 101 |
-
st.caption("Layer norm")
|
| 102 |
-
|
|
|
|
|
|
|
| 103 |
|
| 104 |
st.header('MLP')
|
| 105 |
st.subheader('First Linear')
|
|
@@ -113,3 +115,17 @@ mlp2_flop = 2*bs*1*d*4*d
|
|
| 113 |
mlp2_bytes = 2*bs*1*d + 2*d*4*d + 2*bs*1*4*d
|
| 114 |
c1, c2 = st.columns([2, 3])
|
| 115 |
mlp2_time = print_kernel_execution(c1, c2, mlp2_flop, mlp2_bytes)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 76 |
mha_flop = 2*bs*h*n*(d/h)
|
| 77 |
mha_bytes = 2*bs*h*n + 2*bs*h*n*(d/h) + 2*bs*h*(d/h)
|
| 78 |
c1, c2 = st.columns([2, 3])
|
| 79 |
+
att2_mha_time = print_kernel_execution(c1, c2, mha_flop, mha_bytes)
|
| 80 |
|
| 81 |
st.caption("Multi-Query Attention")
|
| 82 |
mqa_flop = 2*bs*h*n*(d/h)
|
| 83 |
mqa_bytes = 2*bs*n*(d/h) + 2*bs*n*(d/h) + 2*bs*h*(d/h)
|
| 84 |
c1, c2 = st.columns([2, 3])
|
| 85 |
+
att2_mqa_time = print_kernel_execution(c1, c2, mqa_flop, mqa_bytes)
|
| 86 |
|
| 87 |
st.subheader('Output projection')
|
| 88 |
out_flop = 2*bs*1*d*d
|
|
|
|
| 91 |
out_time = print_kernel_execution(c1, c2, out_flop, out_bytes)
|
| 92 |
|
| 93 |
st.subheader('Element-wise ops')
|
| 94 |
+
st.write("We also need to take into the softmax layer, layer norm, and residual connection. We assume that these operations are memory bound. ")
|
| 95 |
|
| 96 |
st.caption("Softmax")
|
| 97 |
softmax_bytes = 2*bs*h*n + 2*bs*h*n
|
| 98 |
c1, c2 = st.columns([2, 3])
|
| 99 |
softmax_time = print_kernel_execution(c1, c2, 0, softmax_bytes)
|
| 100 |
|
| 101 |
+
st.caption("Layer norm/residual connection")
|
| 102 |
+
ln_bytes = 2*bs*1*d
|
| 103 |
+
ln_flop = 0
|
| 104 |
+
ln_time = print_kernel_execution(c1, c2, 0, softmax_bytes)
|
| 105 |
|
| 106 |
st.header('MLP')
|
| 107 |
st.subheader('First Linear')
|
|
|
|
| 115 |
mlp2_bytes = 2*bs*1*d + 2*d*4*d + 2*bs*1*4*d
|
| 116 |
c1, c2 = st.columns([2, 3])
|
| 117 |
mlp2_time = print_kernel_execution(c1, c2, mlp2_flop, mlp2_bytes)
|
| 118 |
+
|
| 119 |
+
st.subheader('Element-wise ops')
|
| 120 |
+
st.write("We also need to take into the GeLU, layer norm, and residual connection. We assume that these operations are memory bound. ")
|
| 121 |
+
ln_bytes = 2*bs*1*d
|
| 122 |
+
ln_flop = 0
|
| 123 |
+
ln_time = print_kernel_execution(c1, c2, 0, softmax_bytes)
|
| 124 |
+
|
| 125 |
+
st.header("Adding it all up")
|
| 126 |
+
|
| 127 |
+
shared_time = out_time + softmax_time + 2*ln_time + mlp1_time + mlp2_time + 3*ln_time
|
| 128 |
+
mha_total_time = qkv_mha_time + att1_mha_time + att2_mha_time + shared_time
|
| 129 |
+
mqa_total_time = qkv_mqa_time + att1_mqa_time + att2_mqa_time + shared_time
|
| 130 |
+
st.write("MHA exec time (ms): " + str(mha_total_time))
|
| 131 |
+
st.write("MQA exec time (ms): " + str(mqa_total_time))
|