from state import Model, Parallelism, Training
from dtypes import DType


class MemoryCalculation:
    def __init__(
        self,
        modelconfig: Model,
        parallelismconfig: Parallelism,
        trainingconfig: Training,
    ):
        self.model = modelconfig
        self.parallelism = parallelismconfig
        self.training = trainingconfig

    def calculate_num_parameters(self) -> float:
        # https://michaelwornow.net/2024/01/18/counting-params-in-transformer
        # https://huggingface.co/spaces/nanotron/ultrascale-playbook?section=memory_usage_in_transformers
        # Biases are not added/omitted on a per-model basis for simplicity.
        # Just include them where they could appear. They're small in comparison
        # to weights anyway, and this forms an upper bound.
        h, i, l, v, e = (
            self.model.hidden_dim,
            self.model.intermediate_size,
            self.model.num_layers,
            self.model.vocab_size,
            self.model.total_experts,
        )
        tp, pp, ep = (
            self.parallelism.tensor_parallelism,
            self.parallelism.pipeline_parallelism,
            self.parallelism.expert_parallelism,
        )

        # Embedding layers
        input_embedding = v * h / tp
        unembedding = 0
        if not self.model.weight_tied_embeddings:
            unembedding = h * v / tp

        # Attention
        layer_norm_attn_in = 2 * h  # weight and bias; not tp sharded
        qkv = (3 * h * h + 3 * h) / tp  # column-parallel: weights and biases sharded
        attn_output_proj = h * h / tp + h  # row-parallel: weight sharded, bias replicated
        attn = layer_norm_attn_in + qkv + attn_output_proj

        # MLP
        layer_norm_mlp_in = 2 * h  # weight and bias; not tp sharded
        router = h * e + e  # assuming replicated for simplicity
        mlp_up_proj = (h * i + i) / tp  # column-parallel: weight and bias sharded
        mlp_gate_proj = (h * i + i) / tp  # column-parallel: weight and bias sharded
        mlp_down_proj = i * h / tp + h  # row-parallel: weight sharded, bias replicated
        expert = mlp_up_proj + mlp_gate_proj + mlp_down_proj
        experts = expert * e / ep
        mlp = layer_norm_mlp_in + router + experts

        layer = attn + mlp
        layers = layer * l
        final_layer_norm = 2 * h  # not tp sharded

        # pp and weight tying make it hard to know which stage holds the embedding,
        # so assume the "worst" case: it sits at the end with the final layer norm,
        # even though that contribution is pretty small.
        total_params = 0
        if pp == 1:
            total_params = input_embedding + layers + unembedding + final_layer_norm
        if pp > 1:
            total_params = max(input_embedding, unembedding) + layers / pp + final_layer_norm
        return total_params

    def calculate_activation_parameters(self) -> float:
        # https://blog.eleuther.ai/transformer-math/#activations-and-batch-size
        # https://arxiv.org/abs/2205.05198
        # pp is not considered since most pp schedules run multiple concurrent
        # microbatches to reduce the bubble.
        b, s = self.training.batch_size, self.training.sequence_length
        h, i, l, v, e, ae = (
            self.model.hidden_dim,
            self.model.intermediate_size,
            self.model.num_layers,
            self.model.vocab_size,
            self.model.total_experts,
            self.model.active_experts,
        )
        tp, cp, pp, ep = (
            self.parallelism.tensor_parallelism,
            self.parallelism.context_parallelism,
            self.parallelism.pipeline_parallelism,
            self.parallelism.expert_parallelism,
        )
        sp = tp

        if self.training.gradient_checkpointing:
            # full recomputation
            embed = 0
            layer = s * b * h / cp / tp  # only keep the initial input to each layer
            layers = layer * l
            final_layer_out = s * b * h / cp / sp
            final_norm = s * b * h / cp / sp
            unembed = s * b * v / cp / tp
            logits = s * b * v / cp / sp  # come back to this
            num_params = embed + layers + final_layer_out + final_norm + unembed + logits
            return num_params
        else:
            # assume flash attention, i.e. do selective recomputation
            # assume tensor parallel + sequence parallel as described in https://arxiv.org/abs/2205.05198
            # each variable below is the activation output of that op

            # Attention Block
            layer_in = s * b * h / cp / tp
            attn_norm = s * b * h / cp / sp
            flash = s * b * h / cp / tp  # everything else is recalculated by flash attention
            projection = s * b * h / cp / tp
            attn = layer_in + attn_norm + flash + projection

            # MLP Block
            mlp_norm = s * b * h / cp / sp
            mlp_up = s * b * i / cp / tp
            mlp_gate = s * b * i / cp / tp
            hadamard_swiglu = s * b * i / cp / tp
            mlp_down = s * b * h / cp / tp
            if self.model.is_moe:
                router = s * b * e / cp / sp  # makes sense to sp shard if the mlp_norm output is sp sharded
                expert = mlp_up + mlp_gate + hadamard_swiglu + mlp_down
                experts = expert * ae
                mlp = mlp_norm + router + experts
            else:
                mlp = mlp_norm + mlp_up + mlp_gate + hadamard_swiglu + mlp_down

            layer = attn + mlp
            layers = layer * l  # no decrease from pp because schedules will increase microbatches

            # Other
            embed = 0
            final_layer_out = s * b * h / cp / tp  # both sequence and context parallelism
            final_norm = s * b * h / cp / sp
            unembed = s * b * v / cp / tp
            logits = s * b * v / cp / tp
            num_params = embed + layers + final_layer_out + final_norm + unembed + logits
            return num_params

    def calculate_parameter_memory(self) -> float:
        if self.training.mixed_precision:
            master_copy = self.calculate_num_parameters() * self.training.precision
            working_copy = self.calculate_num_parameters() * self.training.param_dtype
            return master_copy + working_copy
        else:
            return self.calculate_num_parameters() * self.training.precision

    def calculate_gradient_memory(self) -> float:
        # https://blog.eleuther.ai/transformer-math/#gradients
        # one gradient per parameter, assumed stored in fp32 (4 bytes each)
        return self.calculate_num_parameters() * 4

    def calculate_optimizer_memory(self) -> float:
        # https://blog.eleuther.ai/transformer-math/#optimizer-states
        # https://www.determined.ai/blog/act-mem-2
        # https://web.archive.org/web/20250308172134/https://www.determined.ai/blog/act-mem-2
        # Adam keeps 2 states per parameter, assumed to always be fp32
        return 2 * self.calculate_num_parameters() * DType.FP32

    def calculate_activation_memory(self) -> float:
        if self.training.mixed_precision:
            return self.calculate_activation_parameters() * self.training.param_dtype
        else:
            return self.calculate_activation_parameters() * self.training.precision
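

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only): how the four calculators are meant to be
# combined into a rough per-device training memory estimate. It assumes the
# Model / Parallelism / Training objects from state.py carry the fields read
# above and that DType members (e.g. DType.FP32) are byte sizes; the helper
# name below is hypothetical, adjust to the real definitions if they differ.
# ---------------------------------------------------------------------------
def estimate_total_training_memory(calc: MemoryCalculation) -> float:
    """Sum parameter, gradient, optimizer-state, and activation memory (bytes)."""
    return (
        calc.calculate_parameter_memory()
        + calc.calculate_gradient_memory()
        + calc.calculate_optimizer_memory()
        + calc.calculate_activation_memory()
    )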