# llm_memory_visualizer/calculator.py
from state import Model, Parallelism, Training
class MemoryCalculation:
def __init__(self, model: Model, parallelism: Parallelism, training: Training):
self.model = model
self.parallelism = parallelism
self.training = training
def calculate_num_parameters(self) -> float:
# https://michaelwornow.net/2024/01/18/counting-params-in-transformer
# https://huggingface.co/spaces/nanotron/ultrascale-playbook?section=memory_usage_in_transformers
        # For simplicity, biases are not added or omitted on a per-model basis:
        # they are counted wherever they could appear. They are small compared to
        # the weights, so this just forms a slight upper bound.
h, i, l, v, e = (
self.model.hidden_dim,
self.model.intermediate_size,
self.model.num_layers,
self.model.vocab_size,
self.model.experts,
)
tp, pp, ep = (
self.parallelism.tensor_parallelism,
self.parallelism.pipeline_parallelism,
self.parallelism.expert_parallelism,
)
# Embedding layers
input_embedding = v * h / tp
unembedding = 0
if not self.model.weight_tied_embeddings:
unembedding = h * v / tp
# Attention
        # each layer norm has a weight and a bias, hence the factor of 2
        layer_norm_attn_in = 2 * h  # not tp sharded
qkv = 3 * h * h / tp
        attn_output_proj = (h * h + h) / tp
attn = layer_norm_attn_in + qkv + attn_output_proj
# MLP
layer_norm_mlp_in = 2 * h # not tp sharded
router = h * e + e # assuming replicated for simplicity
        mlp_up_proj = (h * i + i) / tp
        mlp_gate_proj = (h * i + i) / tp
        mlp_down_proj = (i * h + h) / tp
expert = mlp_up_proj + mlp_gate_proj + mlp_down_proj
experts = expert * e / ep
mlp = layer_norm_mlp_in + router + experts
layer = attn + mlp
layers = layer * l
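        # Rough sanity check: with tp = pp = ep = 1 and a single dense expert (e = 1),
        # ignoring the small bias/norm terms, this comes out to about
        # l * (4*h**2 + 3*h*i) parameters for the layers plus 2*v*h for the embeddings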
final_layer_norm = 2 * h # not tp sharded
        # pp and weight tying make it hard to know which stage holds the embedding
        # layer; assume the "worst" case where it sits on the last stage together
        # with the final layer norm, even though that term is pretty small
        if pp == 1:
            total_params = input_embedding + layers + unembedding + final_layer_norm
        else:
            total_params = (
                max(input_embedding, unembedding) + layers / pp + final_layer_norm
            )
return total_params
def calculate_parameter_memory(self) -> float:
return (
self.calculate_num_parameters() * 4
) # assuming 4 bytes (32 bits) per parameter
def calculate_activation_parameters(self) -> float:
# https://blog.eleuther.ai/transformer-math/#activations-and-batch-size
# https://arxiv.org/abs/2205.05198
        # pp not considered: each stage holds fewer layers, but most pp schedules keep
        # roughly pp micro-batches in flight to shrink the bubble, so per-device
        # activation memory stays about the same
b, s = self.training.batch_size, self.training.sequence_length
        h, i, l, v, e, ae = (
            self.model.hidden_dim,
            self.model.intermediate_size,
            self.model.num_layers,
            self.model.vocab_size,
            self.model.experts,
            self.model.active_experts,
        )
tp, cp, pp, ep = (
self.parallelism.tensor_parallelism,
self.parallelism.context_parallelism,
self.parallelism.pipeline_parallelism,
self.parallelism.expert_parallelism,
)
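        # Two regimes are modelled below, loosely following the selective-recomputation
        # accounting in https://arxiv.org/abs/2205.05198: with gradient checkpointing,
        # only each layer's input plus the final norm/projection/logits are kept for
        # the backward pass; otherwise the per-layer norm, attention, and MLP outputs
        # are kept and flash attention recomputes its internals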
        if self.training.gradient_checkpointing:
            # full recomputation: only layer inputs and the final outputs are kept
            embed = 0
            layer = s * b * h / cp / tp  # only keep the initial input to each layer
            layers = layer * l
            final_layer_out = (
                s * b * h / cp / tp
            )  # both sequence and context parallelism
            final_norm = s * b * h / cp / tp  # both sequence and context parallelism
            unembed = s * b * v / cp / tp  # sharded along context (cp) and vocab (tp)
            logits = s * b * v / cp / tp  # sharded along context (cp) and vocab (tp)
            num_activations = (
                embed + layers + final_layer_out + final_norm + unembed + logits
            )
            return num_activations
        else:
            # assume flash attention, i.e. selective recomputation
            # assume tensor parallel + sequence parallel as described in https://arxiv.org/abs/2205.05198
            # each variable below is the size of the activation output that must be kept
# Attention Block
layer_in = s * b * h / cp / tp # both sequence and context parallelism
attn_norm = s * b * h / cp / tp # both sequence and context parallelism
flash = s * b * h / cp / tp
# everything else is recalculated by flash attention
projection = s * b * h / cp / tp
attn = layer_in + attn_norm + flash + projection
# MLP Block
mlp_norm = s * b * h / cp / tp # both sequence and context parallelism
router = (
s * b * e / cp / tp
) # makes sense to sp shard if mlp_norm out is sp sharded
mlp_up = s * b * i / cp / tp
mlp_gate = s * b * i / cp / tp
hadamard_swiglu = s * b * i / cp / tp
mlp_down = s * b * h / cp / tp
expert = mlp_up + mlp_gate + hadamard_swiglu + mlp_down
experts = expert * ae
mlp = mlp_norm + router + experts
layer = attn + mlp
layers = layer * l
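            # Rough sanity check: with tp = cp = 1 and ae = 1, each layer keeps about
            # s * b * (6*h + 3*i + e) activation values under this accounting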
# Other
embed = 0
final_layer_out = (
s * b * h / cp / tp
) # both sequence and context parallelism
final_norm = s * b * h / cp / tp # both sequence and context parallelism
unembed = s * b * v / cp / tp
logits = s * b * v / cp / tp
            num_activations = (
                embed + layers + final_layer_out + final_norm + unembed + logits
            )
            return num_activations
def calculate_activation_memory(self) -> float:
return (
self.calculate_activation_parameters() * 4
) # assuming 4 bytes (32 bits) per activation
def calculate_gradient_memory(self) -> float:
# https://blog.eleuther.ai/transformer-math/#gradients
return (
self.calculate_parameter_memory()
) # gradients are same size as parameters
def calculate_optimizer_memory(self) -> float:
# https://blog.eleuther.ai/transformer-math/#optimizer-states
# https://www.determined.ai/blog/act-mem-2, https://web.archive.org/web/20250308172134/https://www.determined.ai/blog/act-mem-2
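        # Worked example of the accounting: plain fp32 Adam stores two extra values
        # per parameter (first and second moments), each 4 bytes, so the optimizer
        # states cost 8 bytes per parameter, i.e. twice the parameter memory above.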
        return (
            2 * self.calculate_parameter_memory()
        )  # Adam keeps two extra fp32 states (momentum and variance) per parameter
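

# ---------------------------------------------------------------------------
# Minimal usage sketch. It assumes state.Model, state.Parallelism, and
# state.Training accept the fields referenced above as keyword arguments
# (check state.py for the real constructors); the numeric values are just
# illustrative Llama-3-8B-ish settings.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    model = Model(
        hidden_dim=4096,
        intermediate_size=14336,
        num_layers=32,
        vocab_size=128256,
        experts=1,
        active_experts=1,
        weight_tied_embeddings=False,
    )
    parallelism = Parallelism(
        tensor_parallelism=1,
        context_parallelism=1,
        pipeline_parallelism=1,
        expert_parallelism=1,
    )
    training = Training(
        batch_size=1,
        sequence_length=4096,
        gradient_checkpointing=True,
    )
    calc = MemoryCalculation(model, parallelism, training)
    gib = 1024**3
    print(f"parameter memory:  {calc.calculate_parameter_memory() / gib:.2f} GiB")
    print(f"activation memory: {calc.calculate_activation_memory() / gib:.2f} GiB")
    print(f"gradient memory:   {calc.calculate_gradient_memory() / gib:.2f} GiB")
    print(f"optimizer memory:  {calc.calculate_optimizer_memory() / gib:.2f} GiB")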