def activation_memory(
    a,              # number of attention heads
    b,              # micro-batch size
    h,              # hidden dimension
    h_ff,           # feedforward (intermediate) dimension
    L,              # number of transformer layers
    s,              # sequence length
    mixed=True,     # mixed-precision training
    recomputation="none",
):
    """Estimate activation memory in bytes for one forward pass."""
    # Activations take 2 bytes each under mixed precision and 4 bytes in
    # full precision; dropout masks cost 1 byte per value either way.
    if mixed:
        bytes_per_value = 2
    else:
        bytes_per_value = 4

    # Attention block: five s*b*h-sized activation tensors plus a 1-byte
    # dropout mask, and two a*s*s*b attention tensors (softmax output and
    # attention dropout output) plus a 1-byte attention mask.
    one_layer_attention = s * b * h * (bytes_per_value * 5 + 1) + (
        (2 * bytes_per_value + 1) * a * s * s * b
    )
    # Feedforward block (classic MLP): inputs of the two linear layers, the
    # input of the activation function, and a 1-byte dropout mask.
    one_layer_feedforward_mlp = (
        s * b * h * bytes_per_value          # input of the up-projection
        + s * b * h_ff * bytes_per_value     # input of the activation function
        + s * b * h_ff * bytes_per_value     # input of the down-projection
        + s * b * h                          # dropout mask
    )
    # SwiGLU variant: the gated activation stores three s*b*h_ff-sized
    # intermediates instead of one.
    one_layer_feedforward_swiglu = (
        s * b * h * bytes_per_value
        + s * b * h_ff * bytes_per_value
        + s * b * h_ff * bytes_per_value * 3
        + s * b * h
    )
    # The two layer norms each store their s*b*h input.
    layer_norms = 2 * s * b * h * bytes_per_value

    # Per-layer total. Without recomputation this is attention + feedforward
    # + layer norms (swap in the SwiGLU variant for gated models).
    if recomputation == "none":
        one_layer = one_layer_attention + one_layer_feedforward_mlp + layer_norms
    elif recomputation == "selective":
        # Recompute the attention matrices in the backward pass, dropping the
        # 5*a*s*s*b term; the constant 34 assumes mixed precision and h_ff = 4h.
        one_layer = s * b * h * 34
    elif recomputation == "full":
        # Keep only each layer's fp16 input and recompute everything else.
        one_layer = s * b * h * 2
    else:
        raise ValueError(f"unknown recomputation strategy: {recomputation!r}")

    input_dropout = 0  # embedding dropout mask, ignored here

    total = L * one_layer + input_dropout

    return total
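
A minimal usage sketch (the configuration below is a hypothetical ~1.3B-parameter setup, chosen purely for illustration). With mixed precision and h_ff = 4h, the "none" branch works out to 34*s*b*h + 5*a*s*s*b bytes per layer, which is why the "selective" branch can use the constant 34:

# Illustrative only: compare recomputation strategies for a hypothetical
# configuration with h_ff = 4 * h.
gib = 2**30
cfg = dict(a=32, b=4, h=2048, h_ff=8192, L=24, s=2048)
for strategy in ("none", "selective", "full"):
    mem = activation_memory(**cfg, recomputation=strategy)
    print(f"{strategy:>9}: {mem / gib:.1f} GiB")
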
def param_grads_opt(
    h,            # hidden dimension
    L,            # number of transformer layers
    s,            # sequence length (learned positional embeddings)
    v,            # vocabulary size
    k=8,          # optimizer bytes per parameter (Adam: fp32 momentum + variance)
    mixed=True,   # mixed-precision training
):
    """Estimate memory in bytes for parameters, gradients and optimizer states."""
    emb = h * (v + s)               # token + positional embedding parameters
    one_layer = 12 * h**2 + 13 * h  # attention + MLP + layer norms, assuming h_ff = 4h
    other = 2 * h                   # final layer norm

    n = emb + L * one_layer + other  # total parameter count

    if mixed:
        # Mixed precision keeps an fp32 master copy of the weights alongside
        # the optimizer states: 4 extra bytes per parameter.
        k += 4
        bytes_per_parameter = 2
    else:
        bytes_per_parameter = 4

    # bytes for parameters, gradients and optimizer states, respectively
    return bytes_per_parameter * n, bytes_per_parameter * n, k * n
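
And a matching sketch for the static memory, using the same hypothetical configuration as above (the GPT-2 vocabulary size is assumed only for illustration):

# Illustrative only: static training memory for a hypothetical ~1.3B model.
gib = 2**30
params, grads, opt = param_grads_opt(h=2048, L=24, s=2048, v=50257, mixed=True)
print(f"parameters: {params / gib:.2f} GiB")
print(f"gradients:  {grads / gib:.2f} GiB")
print(f"optimizer:  {opt / gib:.2f} GiB")
print(f"total:      {(params + grads + opt) / gib:.2f} GiB")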