# NOTE: Hugging Face Space page-scrape artifact removed (status banner read "Spaces: Sleeping").
from state import Model as Model, Parallelism, Training
class MemoryCalculation:
    """Estimate per-GPU training memory for a transformer under a given
    parallelism layout.

    Four components are modeled: parameters, activations, gradients, and
    optimizer states. Counts are deliberate upper bounds: biases are always
    included even though some models omit them — they are tiny relative to
    the weight matrices.
    """

    def __init__(self, model: Model, parallelism: Parallelism, training: Training):
        self.model = model
        self.parallelism = parallelism
        self.training = training

    def calculate_num_parameters(self) -> float:
        """Return the per-GPU parameter count after tp/pp/ep sharding.

        References:
        - https://michaelwornow.net/2024/01/18/counting-params-in-transformer
        - https://huggingface.co/spaces/nanotron/ultrascale-playbook?section=memory_usage_in_transformers

        Biases are not added/omitted on a per-model basis for simplicity:
        they are included wherever they could appear, which keeps the result
        an upper bound.
        """
        h, i, l, v, e = (
            self.model.hidden_dim,
            self.model.intermediate_size,
            self.model.num_layers,
            self.model.vocab_size,
            self.model.experts,
        )
        tp, pp, ep = (
            self.parallelism.tensor_parallelism,
            self.parallelism.pipeline_parallelism,
            self.parallelism.expert_parallelism,
        )
        # Embedding layers (vocab dimension sharded across tp)
        input_embedding = v * h / tp
        # Tied embeddings share one weight matrix, so the unembedding is free.
        unembedding = 0 if self.model.weight_tied_embeddings else h * v / tp
        # Attention. Layer norm contributes a scale and a bias (2*h),
        # replicated across tp ranks.
        layer_norm_attn_in = 2 * h
        qkv = 3 * h * h / tp  # column-parallel projection
        # BUGFIX: was `h * h + h / tp`, which — by operator precedence — only
        # sharded the bias; the h*h weight must be tp-sharded as well, to be
        # consistent with qkv above.
        attn_output_proj = (h * h + h) / tp
        attn = layer_norm_attn_in + qkv + attn_output_proj
        # MLP
        layer_norm_mlp_in = 2 * h  # not tp sharded
        router = h * e + e  # assuming replicated for simplicity
        # BUGFIX: same precedence issue as attn_output_proj — parenthesized so
        # the weight matrices (not just the biases) are divided by tp.
        mlp_up_proj = (h * i + i) / tp
        mlp_gate_proj = (h * i + i) / tp
        mlp_down_proj = (i * h + h) / tp
        expert = mlp_up_proj + mlp_gate_proj + mlp_down_proj
        experts = expert * e / ep
        mlp = layer_norm_mlp_in + router + experts
        layers = (attn + mlp) * l
        final_layer_norm = 2 * h  # not tp sharded
        # pp and weight tying make it ambiguous which stage holds the
        # embedding; assume the "worst" case — it sits on the last stage with
        # the final layer norm (even though that extra cost is pretty small).
        if pp == 1:
            return input_embedding + layers + unembedding + final_layer_norm
        # pp > 1: the largest embedding dominates one stage; layers split evenly.
        return max(input_embedding, unembedding) + layers / pp + final_layer_norm

    def calculate_parameter_memory(self) -> float:
        """Return parameter memory in bytes, assuming fp32 (4 bytes/param)."""
        return self.calculate_num_parameters() * 4

    def calculate_activation_parameters(self) -> float:
        """Return the per-GPU activation element count for one training step.

        References:
        - https://blog.eleuther.ai/transformer-math/#activations-and-batch-size
        - https://arxiv.org/abs/2205.05198

        pp is not considered: most pp schemes run multiple concurrent
        micro-batches to shrink the bubble, so per-stage activation memory is
        roughly unchanged.
        """
        b, s = self.training.batch_size, self.training.sequence_length
        # BUGFIX: the original unpacked six names (h, i, l, v, e, ae) from a
        # five-element tuple, a guaranteed ValueError; `e` was never used.
        h, i, l, v, ae = (
            self.model.hidden_dim,
            self.model.intermediate_size,
            self.model.num_layers,
            self.model.vocab_size,
            self.model.active_experts,
        )
        tp, cp = (
            self.parallelism.tensor_parallelism,
            self.parallelism.context_parallelism,
        )
        if self.training.gradient_checkpointing:
            # Full recomputation: only the input to each layer is kept.
            layer = s * b * h / cp / tp
            layers = layer * l
            embed = 0
            final_layer_out = s * b * h / cp / tp  # sequence + tensor parallel
            final_norm = s * b * h / cp / tp  # sequence + tensor parallel
            unembed = s * b * v / cp / tp
            logits = s * b * v / cp / tp  # vocab + tensor parallel
            return embed + layers + final_layer_out + final_norm + unembed + logits
        # Assume flash attention, i.e. selective recomputation, and tensor
        # parallel + sequence parallel as described in arXiv:2205.05198.
        # Each variable below is the size of one activation OUTPUT.
        # Attention block
        layer_in = s * b * h / cp / tp  # sequence + context parallel
        attn_norm = s * b * h / cp / tp  # sequence + context parallel
        flash = s * b * h / cp / tp
        # everything else inside attention is recomputed by flash attention
        projection = s * b * h / cp / tp
        attn = layer_in + attn_norm + flash + projection
        # MLP block
        mlp_norm = s * b * h / cp / tp  # sequence + context parallel
        # Router output sp-sharded, matching the sharded mlp_norm output.
        router = s * b * e_dim / cp / tp if False else s * b * self.model.experts / cp / tp
        mlp_up = s * b * i / cp / tp
        mlp_gate = s * b * i / cp / tp
        hadamard_swiglu = s * b * i / cp / tp
        mlp_down = s * b * h / cp / tp
        expert = mlp_up + mlp_gate + hadamard_swiglu + mlp_down
        experts = expert * ae  # only active experts produce activations
        mlp = mlp_norm + router + experts
        layer = attn + mlp
        layers = layer * l
        # Other
        embed = 0
        final_layer_out = s * b * h / cp / tp  # sequence + context parallel
        final_norm = s * b * h / cp / tp  # sequence + context parallel
        unembed = s * b * v / cp / tp
        logits = s * b * v / cp / tp
        return embed + layers + final_layer_out + final_norm + unembed + logits

    def calculate_activation_memory(self) -> float:
        """Return activation memory in bytes, assuming fp32 (4 bytes/element)."""
        return self.calculate_activation_parameters() * 4

    def calculate_gradient_memory(self) -> float:
        """Return gradient memory in bytes.

        https://blog.eleuther.ai/transformer-math/#gradients
        Gradients are the same size as the parameters.
        """
        return self.calculate_parameter_memory()

    def calculate_optimizer_memory(self) -> float:
        """Return optimizer-state memory in bytes.

        https://blog.eleuther.ai/transformer-math/#optimizer-states
        https://www.determined.ai/blog/act-mem-2

        Adam keeps 2 extra states per parameter (momentum and variance), so
        optimizer memory is twice the parameter memory. (The original comment
        claimed "3 states", which contradicted the 2x the code computes.)
        """
        return 2 * self.calculate_parameter_memory()