nouamanetazi HF Staff committed on
Commit
3f50411
·
1 Parent(s): 5357354
Files changed (1) hide show
  1. utils.py +3 -1
utils.py CHANGED
@@ -76,7 +76,9 @@ def calculate_memory_components(
76
  overhead = 72 + 32 * mbs
77
 
78
  # Activations
79
- decoder_layer_mib = (seq_len * mbs * hidden_size/tp) * (2/1024/1024) * (4*intermediate_size/hidden_size + 10)
 
 
80
 
81
  if pp > 1:
82
  activs = min(pp, batch_accum) * num_hidden_layers_in_pp * decoder_layer_mib
 
76
  overhead = 72 + 32 * mbs
77
 
78
  # Activations
79
+ # decoder_layer_mib = (seq_len * mbs * hidden_size/tp) * (2/1024/1024) * (4*intermediate_size/hidden_size + 10)
80
+ is_mha = num_key_value_heads == num_attention_heads
81
+ decoder_layer_mib = (seq_len * mbs * hidden_size/tp) * (2/1024/1024) * (4*intermediate_size/hidden_size + 12 + 2*num_key_value_heads/num_attention_heads + (2 if is_mha else 0))
82
 
83
  if pp > 1:
84
  activs = min(pp, batch_accum) * num_hidden_layers_in_pp * decoder_layer_mib