Spaces:
Running
Running
Commit ·
3f50411
1
Parent(s): 5357354
fix gqa
Browse files
utils.py
CHANGED
|
@@ -76,7 +76,9 @@ def calculate_memory_components(
|
|
| 76 |
overhead = 72 + 32 * mbs
|
| 77 |
|
| 78 |
# Activations
|
| 79 |
-
decoder_layer_mib = (seq_len * mbs * hidden_size/tp) * (2/1024/1024) * (4*intermediate_size/hidden_size + 10)
|
|
|
|
|
|
|
| 80 |
|
| 81 |
if pp > 1:
|
| 82 |
activs = min(pp, batch_accum) * num_hidden_layers_in_pp * decoder_layer_mib
|
|
|
|
| 76 |
overhead = 72 + 32 * mbs
|
| 77 |
|
| 78 |
# Activations
|
| 79 |
+
# decoder_layer_mib = (seq_len * mbs * hidden_size/tp) * (2/1024/1024) * (4*intermediate_size/hidden_size + 10)
|
| 80 |
+
is_mha = num_key_value_heads == num_attention_heads
|
| 81 |
+
decoder_layer_mib = (seq_len * mbs * hidden_size/tp) * (2/1024/1024) * (4*intermediate_size/hidden_size + 12 + 2*num_key_value_heads/num_attention_heads + (2 if is_mha else 0))
|
| 82 |
|
| 83 |
if pp > 1:
|
| 84 |
activs = min(pp, batch_accum) * num_hidden_layers_in_pp * decoder_layer_mib
|