|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from torch import nn |
|
|
from torchtitan.tools.logging import logger |
|
|
|
|
|
|
|
|
def get_nparams_and_flops(model: nn.Module, model_config, seq_len: int) -> tuple[int, int]:
    """Count model parameters and estimate the training FLOPs per token.

    Args:
        model: The model whose parameters are counted.
        model_config: Config object; must expose ``num_hidden_layers`` and
            ``hidden_size``, and ideally ``num_heads`` or
            ``num_attention_heads`` (falls back to 1 with a warning).
        seq_len: Training sequence length used in the attention FLOPs term.

    Returns:
        A tuple ``(nparams, num_flops_per_token)``.
    """
    nparams = sum(p.numel() for p in model.parameters())

    # Traverse ALL submodules, not just direct children, so embeddings nested
    # inside wrapper modules (e.g. HF-style `model.model.embed_tokens`) are
    # still excluded from the dense-matmul FLOPs term. `nn.Embedding` contains
    # no nested embeddings, so this cannot double-count.
    nparams_embedding = sum(
        sum(p.numel() for p in m.parameters())
        for m in model.modules()
        if isinstance(m, nn.Embedding)
    )

    # Support both naming conventions for the head count.
    if hasattr(model_config, "num_heads"):
        num_heads = model_config.num_heads
    elif hasattr(model_config, "num_attention_heads"):
        num_heads = model_config.num_attention_heads
    else:
        num_heads = 1
        logger.warning("num_heads not found in model_config, defaulting to 1.")

    l, h, q, t = (
        model_config.num_hidden_layers,
        num_heads,
        model_config.hidden_size // num_heads,
        seq_len,
    )

    # Standard estimate (cf. PaLM paper, Appendix B):
    #   6 * (non-embedding params) — 2 FLOPs per MAC across the forward pass
    #   plus twice that for the backward pass;
    #   12 * l * h * q * t          — attention score and value matmuls,
    #   which scale with sequence length and are not captured by the
    #   parameter count.
    num_flops_per_token = 6 * (nparams - nparams_embedding) + 12 * l * h * q * t

    return nparams, num_flops_per_token
|
|
|