from typing import List, Tuple

import lightning.pytorch as pl
import numpy as np
import torch
from lightning.pytorch.callbacks import Callback

from nemo.collections.llm.gpt.model.base import GPTConfig
from nemo.lightning.pytorch.callbacks import PEFT
from nemo.utils import flops_formulas, logging
from nemo.utils.hyena_flops_formulas import hyena

__all__ = ["FLOPsMeasurementCallback", "MM_FLOPsMeasurementCallback"]
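

# Maps a model-family key to its FLOPs-per-step formula. ``eval_model_flops``
# resolves the key by substring match against the lowercased model name (e.g.
# "llama3_8b" matches the "llama3" entry). Note that "llama4" currently reuses
# the llama3 formula.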
_model_flops_map = {
    "gpt3": flops_formulas.gpt3,
    "llama2": flops_formulas.llama2,
    "llama3": flops_formulas.llama3,
    "llama4": flops_formulas.llama3,
    "nemotron3": flops_formulas.nemotron,
    "nemotron4": flops_formulas.nemotron,
    "mixtral": flops_formulas.mixtral,
    "bert": flops_formulas.bert,
    "hyena": hyena,
    "deepseekv3": flops_formulas.deepseekv3,
    "transformer": flops_formulas.transformer,
    "qwen3": flops_formulas.qwen3,
    "nemotronh": flops_formulas.nemotronh,
}


class FLOPsMeasurementCallback(Callback):
    """
    Calculate and log FLOPs per second after every ``trainer.log_every_n_steps`` steps.

    Args:
        model_config (GPTConfig): Model parameters.
        data_config (pl.LightningDataModule): Data module being used in the experiment.
        model_name (str): Name of the model being run. Any name containing one of the
            keys in ``_model_flops_map`` is supported: gpt3, llama2, llama3, llama4,
            nemotron3, nemotron4, mixtral, bert, hyena, deepseekv3, transformer,
            qwen3, nemotronh.
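
    Example (a minimal sketch; ``model_config``, ``data_module``, and the trainer
    setup are illustrative, and a ``TimingCallback`` must be attached so that
    ``train_step_timing in s`` appears in the progress-bar metrics):

        >>> flops_callback = FLOPsMeasurementCallback(model_config, data_module, "llama3")
        >>> trainer = pl.Trainer(callbacks=[TimingCallback(), flops_callback])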
    """

    higher_is_better = True

    def __init__(
        self,
        model_config: GPTConfig,
        data_config: pl.LightningDataModule,
        model_name: str,
    ):
        self.model_cfg = model_config
        self.data_cfg = data_config
        self.model = model_name

        gbs = self.data_cfg.global_batch_size
        enc_seq_len = self.model_cfg.seq_length
        hs = self.model_cfg.hidden_size
        layers = self.model_cfg.num_layers
        ffn_hs = self.model_cfg.ffn_hidden_size
        attention_heads = self.model_cfg.num_attention_heads
        moe_router_topk = self.model_cfg.moe_router_topk
        model_pattern = getattr(self.model_cfg, "hybrid_override_pattern", None)
        vocab_size = self.data_cfg.tokenizer.vocab_size if hasattr(self.data_cfg, "tokenizer") else None

        # Without grouped-query attention, num_query_groups is None and each
        # attention head forms its own group.
        query_groups = self.model_cfg.num_query_groups
        if query_groups is None:
            query_groups = attention_heads

        config_kwargs = {
            "gbs": gbs,
            "enc_seq_len": enc_seq_len,
            "hs": hs,
            "layers": layers,
            "ffn_hs": ffn_hs,
            "attention_heads": attention_heads,
            "moe_router_topk": moe_router_topk,
            "query_groups": query_groups,
            "vocab_size": vocab_size,
            "model_pattern": model_pattern,
        }

        # MLATransformerConfig (multi-latent attention, e.g. DeepSeek) carries
        # extra fields consumed by the corresponding FLOPs formula.
        from megatron.core.transformer.transformer_config import MLATransformerConfig

        if isinstance(self.model_cfg, MLATransformerConfig):
            config_kwargs["qk_head_dim"] = self.model_cfg.qk_head_dim
            config_kwargs["qk_pos_emb_head_dim"] = self.model_cfg.qk_pos_emb_head_dim
            config_kwargs["v_head_dim"] = self.model_cfg.v_head_dim
            config_kwargs["q_lora_rank"] = self.model_cfg.q_lora_rank
            config_kwargs["kv_lora_rank"] = self.model_cfg.kv_lora_rank
            config_kwargs["moe_layer_freq"] = self.model_cfg.moe_layer_freq
            config_kwargs["moe_shared_expert_intermediate_size"] = self.model_cfg.moe_shared_expert_intermediate_size
            config_kwargs["moe_ffn_hidden_size"] = self.model_cfg.moe_ffn_hidden_size
            config_kwargs["mtp_num_layers"] = self.model_cfg.mtp_num_layers

        if self.model_cfg.is_hybrid_model:
            config_kwargs["is_hybrid_model"] = True
            config_kwargs["hybrid_override_pattern"] = self.model_cfg.hybrid_override_pattern
            config_kwargs["mamba_state_dim"] = self.model_cfg.mamba_state_dim
            config_kwargs["mamba_head_dim"] = self.model_cfg.mamba_head_dim
            config_kwargs["mamba_num_groups"] = self.model_cfg.mamba_num_groups
            config_kwargs["mamba_num_heads"] = self.model_cfg.mamba_num_heads

        self.flops_config = flops_formulas.FLOPSConfig(**config_kwargs)

        self.model = self.model.lower() if self.model is not None else self.model

        self.avg_train_step_time = 0

    def on_train_start(self, trainer, pl_module):
        """
        PyTorch Lightning callback hook. Ensures the user is not using PEFT,
        which this FLOPs callback does not support.
        """
        for callback in trainer.callbacks:
            if isinstance(callback, PEFT):
                raise NotImplementedError("FLOPs measurement not supported for finetuning jobs")

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx: int):
        """
        PyTorch Lightning callback hook that logs TFLOPS after every
        ``trainer.log_every_n_steps`` training steps.
        """
        try:
            self.avg_train_step_time += trainer.progress_bar_metrics['train_step_timing in s']
        except KeyError:
            logging.warning(
                "'train_step_timing in s' not found. Make sure to use TimingCallback with FLOPsMeasurementCallback."
            )
        n = trainer.strategy.current_epoch_step
        if n % trainer.log_every_n_steps == 0:
            if self.avg_train_step_time == 0:
                return
            # Average the accumulated step time over the logging window.
            train_step_time = self.avg_train_step_time / trainer.log_every_n_steps
            tflops_per_gpu, flops = self.eval_tflops_per_sec_per_gpu(train_step_time)
            self.avg_train_step_time = 0
            pl_module.log(
                "TFLOPS_per_GPU",
                tflops_per_gpu,
                on_step=True,
                on_epoch=False,
                batch_size=1,
                prog_bar=True,
            )

            tflops = flops / (1e12 * train_step_time)
            pl_module.log(
                "TFLOPS",
                tflops,
            )
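
            # Worked example with hypothetical numbers: if the model formula yields
            # 4.2e15 FLOPs per step and the averaged step time is 3.0 s on 8 GPUs,
            # then TFLOPS = 4.2e15 / (1e12 * 3.0) = 1400 and TFLOPS_per_GPU = 1400 / 8 = 175.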

    def eval_tflops_per_sec_per_gpu(self, train_step_time: List | float | int) -> Tuple[float, float]:
        """
        Args:
            train_step_time (List | float | int): Train step time (in seconds).
                Step time will be less stable for the initial (~10) steps, giving a
                less accurate measurement; use the average step time over several
                steps for higher accuracy. When a list is passed, only its second
                half is averaged, discarding the noisier early samples.

        Returns:
            Tuple[float, float]: Model TFLOPs per second per GPU, and total model
                FLOPs for one training step.
        """
        total_flops, flops_per_gpu = self.eval_model_flops()

        if not isinstance(train_step_time, list):
            train_step_time = [train_step_time]

        # Keep only the second half of the samples; early step times are noisier.
        step_time_arr = np.array(train_step_time)
        train_step_time = np.mean(step_time_arr[len(step_time_arr) // 2 :])

        flops_per_sec_per_gpu = flops_per_gpu / (1e12 * train_step_time)

        return flops_per_sec_per_gpu, total_flops

    def eval_model_flops(self) -> Tuple[float, float]:
        """
        Calculate total model FLOPs for one training step, and the per-GPU share.
        The model name is resolved by substring match against ``_model_flops_map``.
        """
        if self.model is not None:
            model_matches = [model for model in _model_flops_map if model in self.model]
            self.model = model_matches[0] if len(model_matches) > 0 else self.model
        if self.model not in _model_flops_map:
            logging.info(f"FLOPs measurement supported for {list(_model_flops_map.keys())}")
            raise KeyError(f"Failed to extract a valid model name from, or missing FLOPs calculations for, {self.model}")

        total_flops = _model_flops_map[self.model](self.flops_config)
        num_devices = torch.distributed.get_world_size() if torch.distributed.is_initialized() else 1
        flops_per_gpu = total_flops / num_devices

        return total_flops, flops_per_gpu


class MM_FLOPsMeasurementCallback(FLOPsMeasurementCallback):
    """
    Calculate and log FLOPs per second after every ``trainer.log_every_n_steps`` steps
    for multi-modal models. In addition to the models supported by
    ``FLOPsMeasurementCallback``, the following sub-models are supported:
    hf_clip_vit_l, neva_projection, flux.

    Args:
        model_name_config_dict (dict):
            Dictionary containing all the individual model configs that make up the
            multi-modal model, keyed by sub-model name.
        data_config (pl.LightningDataModule): Data module being used in the experiment.
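
    Example (a minimal sketch; the sub-model keys must match formula names, and
    ``vision_cfg``, ``llm_cfg``, and ``data_module`` are illustrative):

        >>> callback = MM_FLOPsMeasurementCallback(
        ...     {"hf_clip_vit_l": vision_cfg, "llama3": llm_cfg},
        ...     data_module,
        ... )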
    """

    higher_is_better = True

    def __init__(
        self,
        model_name_config_dict: dict,
        data_config: pl.LightningDataModule,
    ):
        self.data_cfg = data_config
        self.flops_config_dict = dict()

        for model_name, model_cfg in model_name_config_dict.items():
            kwargs = dict()
            kwargs["gbs"] = self.data_cfg.global_batch_size
            kwargs["hs"] = model_cfg.hidden_size
            if model_name in ["hf_clip_vit_l"]:
                kwargs["layers"] = model_cfg.num_hidden_layers
                kwargs["img_seq_len"] = model_cfg.num_image_embeddings_per_tile
                kwargs["img_h"] = model_cfg.image_size
                kwargs["img_w"] = model_cfg.image_size
                kwargs["patch_dim"] = model_cfg.patch_size
                kwargs["in_channels"] = model_cfg.num_channels
                kwargs["class_token_len"] = 1
            elif model_name in ["neva_projection"]:
                kwargs["projector_type"] = model_cfg.projector_type
                kwargs["ffn_hs"] = model_cfg.ffn_hidden_size
                kwargs["inp_s"] = model_cfg.input_size
                # The projection consumes the vision encoder's output sequence.
                kwargs["img_seq_len"] = model_name_config_dict["hf_clip_vit_l"].num_image_embeddings_per_tile
            elif model_name in ["flux"]:
                kwargs["layers"] = [model_cfg.num_joint_layers, model_cfg.num_single_layers]
                kwargs["hs"] = model_cfg.hidden_size
                kwargs["model_channels"] = model_cfg.model_channels
                kwargs["inp_s"] = model_cfg.context_dim
                kwargs["in_channels"] = model_cfg.in_channels
                kwargs["vec_in_dim"] = model_cfg.vec_in_dim
            else:
                kwargs["enc_seq_len"] = model_cfg.seq_length
                kwargs["layers"] = model_cfg.num_layers
                kwargs["ffn_hs"] = model_cfg.ffn_hidden_size
                kwargs["attention_heads"] = model_cfg.num_attention_heads
                kwargs["moe_router_topk"] = model_cfg.moe_router_topk

            try:
                query_groups = model_cfg.num_query_groups
                if query_groups is None:
                    query_groups = model_cfg.num_attention_heads
                kwargs["query_groups"] = query_groups
            except AttributeError:
                # Not all sub-model configs (e.g. vision encoders) define query groups.
                pass

            self.flops_config_dict[model_name] = flops_formulas.FLOPSConfig(**kwargs)

        self.avg_train_step_time = 0
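
        # Illustrative: for {"hf_clip_vit_l": ..., "llama3": ...} this leaves one
        # FLOPSConfig per sub-model in self.flops_config_dict; eval_model_flops()
        # below sums their per-step FLOPs.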

    def eval_model_flops(self) -> Tuple[float, float]:
        """
        Calculate model FLOPs when the model is composed of multiple sub-models,
        by summing the FLOPs of each sub-model.
        """
        mm_model_flops_map = {
            **_model_flops_map,
            "hf_clip_vit_l": flops_formulas.clip_vit_l,
            "neva_projection": flops_formulas.neva_projection,
            "flux": flops_formulas.flux,
        }

        total_flops = flops_per_gpu = 0
        for model_name, flops_cfg in self.flops_config_dict.items():
            if model_name not in mm_model_flops_map:
                logging.info(f"FLOPs measurement supported for {list(mm_model_flops_map.keys())}")
                raise KeyError(
                    f"Failed to extract a valid model name from, or missing FLOPs calculations for, {model_name}"
                )
            total_flops += mm_model_flops_map[model_name](flops_cfg)
        num_devices = torch.distributed.get_world_size() if torch.distributed.is_initialized() else 1
        flops_per_gpu = total_flops / num_devices

        return total_flops, flops_per_gpu