| # Copyright (c) Meta Platforms, Inc. and affiliates. | |
| # All rights reserved. | |
| # | |
| # This source code is licensed under the BSD-style license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| from torch import nn | |
| from torchtitan.tools.logging import logger | |
def get_nparams_and_flops(
    model: nn.Module, model_config, seq_len: int
) -> tuple[int, int]:
    """Count model parameters and estimate the training FLOPs per token.

    Args:
        model: Module whose parameters are counted.
        model_config: Config object exposing ``num_hidden_layers`` and
            ``hidden_size``, plus either ``num_heads`` or
            ``num_attention_heads``. If neither head attribute is present,
            a warning is logged and a single head is assumed.
        seq_len: Sequence length used for the self-attention term.

    Returns:
        A ``(nparams, num_flops_per_token)`` tuple, where ``nparams`` is the
        total parameter count (embeddings included) and
        ``num_flops_per_token`` is the estimated forward + backward FLOPs
        per token.
    """
    nparams = sum(p.numel() for p in model.parameters())

    # Embedding weights are excluded from the 6 * N term below. Walk the whole
    # module tree rather than only the direct children so that embeddings
    # nested inside wrapper submodules are found as well. Parameters are keyed
    # by identity so a weight shared between several embedding modules is
    # subtracted only once.
    embedding_params = {
        id(param): param
        for module in model.modules()
        if isinstance(module, nn.Embedding)
        for param in module.parameters()
    }
    nparams_embedding = sum(p.numel() for p in embedding_params.values())

    if hasattr(model_config, "num_heads"):
        num_heads = model_config.num_heads
    elif hasattr(model_config, "num_attention_heads"):
        num_heads = model_config.num_attention_heads
    else:
        num_heads = 1
        logger.warning("num_heads not found in model_config, defaulting to 1. ")

    num_layers = model_config.num_hidden_layers
    head_dim = model_config.hidden_size // num_heads

    # Reasoning behind the factor of 12 for the self-attention part of the formula:
    # 1. each self-attention has 2 matmul in the forward and 4 in the backward (6)
    # 2. the flash attention does 1 more matmul recomputation in the backward
    #    but recomputation should not be counted in calculating MFU           (+0)
    # 3. each matmul performs 1 multiplication and 1 addition                 (*2)
    # 4. we follow the convention and do not account for sparsity in causal attention
    attention_flops = 12 * num_layers * num_heads * head_dim * seq_len
    num_flops_per_token = 6 * (nparams - nparams_embedding) + attention_flops

    return nparams, num_flops_per_token