# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List, Tuple

import lightning.pytorch as pl
import numpy as np
import torch
from lightning.pytorch.callbacks import Callback

from nemo.collections.llm.gpt.model.base import GPTConfig
from nemo.lightning.pytorch.callbacks import PEFT
from nemo.utils import flops_formulas, logging
from nemo.utils.hyena_flops_formulas import hyena

__all__ = ["FLOPsMeasurementCallback", "MM_FLOPsMeasurementCallback"]

_model_flops_map = {
    "gpt3": flops_formulas.gpt3,
    "llama2": flops_formulas.llama2,
    "llama3": flops_formulas.llama3,
    "llama4": flops_formulas.llama3,  # TODO: add llama4 flops formulas
    "nemotron3": flops_formulas.nemotron,
    "nemotron4": flops_formulas.nemotron,
    "mixtral": flops_formulas.mixtral,
    "bert": flops_formulas.bert,
    "hyena": hyena,
    "deepseekv3": flops_formulas.deepseekv3,
    "transformer": flops_formulas.transformer,
    "qwen3": flops_formulas.qwen3,
    "nemotronh": flops_formulas.nemotronh,
}


class FLOPsMeasurementCallback(Callback):
    """
    Calculate and log FLOPs per second after every ``trainer.log_every_n_steps`` steps.

    Args:
        model_config (GPTConfig): Model parameters.
        data_config (pl.LightningDataModule): Data module being used in the experiment.
        model_name (str): Name of the model being run. The following models are supported:
            gpt3, llama2, llama3, llama4, nemotron3, nemotron4, mixtral, bert, hyena,
            deepseekv3, transformer, qwen3, nemotronh.
    """

    higher_is_better = True

    def __init__(
        self,
        model_config: GPTConfig,
        data_config: pl.LightningDataModule,
        model_name: str,
    ):
        self.model_cfg = model_config
        self.data_cfg = data_config

        # use config params only when NOT provided explicitly
        self.model = model_name

        gbs = self.data_cfg.global_batch_size
        enc_seq_len = self.model_cfg.seq_length
        hs = self.model_cfg.hidden_size
        layers = self.model_cfg.num_layers
        ffn_hs = self.model_cfg.ffn_hidden_size
        attention_heads = self.model_cfg.num_attention_heads
        moe_router_topk = self.model_cfg.moe_router_topk
        model_pattern = getattr(self.model_cfg, "hybrid_override_pattern", None)
        vocab_size = self.data_cfg.tokenizer.vocab_size if hasattr(self.data_cfg, "tokenizer") else None

        # this handles both cases: 1. key is present but its value is None; 2. key is absent
        query_groups = self.model_cfg.num_query_groups
        if query_groups is None:
            query_groups = attention_heads

        config_kwargs = {
            "gbs": gbs,
            "enc_seq_len": enc_seq_len,
            "hs": hs,
            "layers": layers,
            "ffn_hs": ffn_hs,
            "attention_heads": attention_heads,
            "moe_router_topk": moe_router_topk,
            "query_groups": query_groups,
            "vocab_size": vocab_size,
            "model_pattern": model_pattern,
        }

        from megatron.core.transformer.transformer_config import MLATransformerConfig

        if isinstance(self.model_cfg, MLATransformerConfig):
            config_kwargs["qk_head_dim"] = self.model_cfg.qk_head_dim
            config_kwargs["qk_pos_emb_head_dim"] = self.model_cfg.qk_pos_emb_head_dim
            config_kwargs["v_head_dim"] = self.model_cfg.v_head_dim
            config_kwargs["q_lora_rank"] = self.model_cfg.q_lora_rank
            config_kwargs["kv_lora_rank"] = self.model_cfg.kv_lora_rank
            config_kwargs["moe_layer_freq"] = self.model_cfg.moe_layer_freq
            config_kwargs["moe_shared_expert_intermediate_size"] = self.model_cfg.moe_shared_expert_intermediate_size
            config_kwargs["moe_ffn_hidden_size"] = self.model_cfg.moe_ffn_hidden_size
            config_kwargs["mtp_num_layers"] = self.model_cfg.mtp_num_layers

        if self.model_cfg.is_hybrid_model:
            config_kwargs["is_hybrid_model"] = True
            config_kwargs["hybrid_override_pattern"] = self.model_cfg.hybrid_override_pattern
            config_kwargs["mamba_state_dim"] = self.model_cfg.mamba_state_dim
            config_kwargs["mamba_head_dim"] = self.model_cfg.mamba_head_dim
            config_kwargs["mamba_num_groups"] = self.model_cfg.mamba_num_groups
            config_kwargs["mamba_num_heads"] = self.model_cfg.mamba_num_heads

        self.flops_config = flops_formulas.FLOPSConfig(**config_kwargs)

        self.model = self.model.lower() if self.model is not None else self.model

        self.avg_train_step_time = 0

    def on_train_start(self, trainer, pl_module):
        """
        PyTorch Lightning callback hook. Ensures that the user is not using PEFT,
        as the FLOPs callback does not support it.
        """
        for callback in trainer.callbacks:
            if isinstance(callback, PEFT):
                raise NotImplementedError("FLOPs measurement not supported for finetuning jobs")

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx: int):
        """
        PyTorch Lightning callback hook to calculate TFLOPs per sec per GPU after a training step.
        """
        try:
            self.avg_train_step_time += trainer.progress_bar_metrics["train_step_timing in s"]
        except KeyError:
            logging.warning(
                "'train_step_timing in s' not found. Make sure to use TimingCallback with FLOPsMeasurementCallback."
            )

        n = trainer.strategy.current_epoch_step
        if n % trainer.log_every_n_steps == 0:
            # skip calculation if we haven't accumulated any timing data
            if self.avg_train_step_time == 0:
                return
            train_step_time = self.avg_train_step_time / trainer.log_every_n_steps
            tflops_per_gpu, flops = self.eval_tflops_per_sec_per_gpu(train_step_time)
            self.avg_train_step_time = 0
            pl_module.log(
                "TFLOPS_per_GPU",
                tflops_per_gpu,
                on_step=True,
                on_epoch=False,
                batch_size=1,
                prog_bar=True,
            )
            tflops = flops / (1e12 * train_step_time)
            pl_module.log(
                "TFLOPS",
                tflops,
            )

    def eval_tflops_per_sec_per_gpu(self, train_step_time: List | float | int) -> Tuple[float, float]:
        """
        Args:
            train_step_time (List | float | int): Train step time (in seconds).
                Step time will be less stable for the initial steps (~10 steps), giving a less
                accurate measurement; use the average step time over several steps for higher
                accuracy.

        Returns:
            (Tuple[float, float]): Model TFLOPs per sec per GPU, and total model FLOPs.
        """
        total_flops, flops_per_gpu = self.eval_model_flops()

        if not isinstance(train_step_time, list):
            train_step_time = [train_step_time]
        # average over the second half of the recorded step times to drop unstable warm-up steps
        step_time_arr = np.array(train_step_time)
        train_step_time = np.mean(step_time_arr[len(step_time_arr) // 2 :])

        flops_per_sec_per_gpu = flops_per_gpu / (1e12 * train_step_time)

        return flops_per_sec_per_gpu, total_flops

    def eval_model_flops(self) -> Tuple[float, float]:
        """
        Calculate model FLOPs for a given model.
        """
        if self.model is not None:
            model_matches = [model for model in _model_flops_map if model in self.model]
            self.model = model_matches[0] if len(model_matches) > 0 else self.model
        if self.model not in _model_flops_map:
            logging.info(f"FLOPs measurement supported for {list(_model_flops_map.keys())}")
            raise KeyError(
                f"Failed to extract a valid model name from, or missing FLOPs calculations for, {self.model}"
            )

        total_flops = _model_flops_map[self.model](self.flops_config)
        num_devices = torch.distributed.get_world_size() if torch.distributed.is_initialized() else 1
        flops_per_gpu = total_flops / num_devices

        return total_flops, flops_per_gpu


class MM_FLOPsMeasurementCallback(FLOPsMeasurementCallback):
    """
    Calculate and log FLOPs per second after every ``trainer.log_every_n_steps`` steps
    for multi-modal models.

    The following models are supported:
        hf_clip_vit_l, neva_projection, flux, and every model supported by
        ``FLOPsMeasurementCallback``.

    Args:
        model_name_config_dict (dict): Dictionary containing all the individual model configs
            that make up the multi-modal model.
        data_config (pl.LightningDataModule): Data module being used in the experiment.
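
    Example:
        A minimal sketch of constructing the callback for a NeVA-style model. The config
        variables (``vision_cfg``, ``projector_cfg``, ``llm_cfg``, ``data_module``) are
        placeholders for the actual sub-model configs and data module of the experiment;
        the dict keys must match the supported model names above::

            callback = MM_FLOPsMeasurementCallback(
                model_name_config_dict={
                    "hf_clip_vit_l": vision_cfg,
                    "neva_projection": projector_cfg,
                    "llama3": llm_cfg,
                },
                data_config=data_module,
            )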
""" higher_is_better = True def __init__( self, model_name_config_dict: dict, data_config: pl.LightningDataModule, ): self.data_cfg = data_config self.flops_config_dict = dict() for model_name, model_cfg in model_name_config_dict.items(): kwargs = dict() kwargs["gbs"] = self.data_cfg.global_batch_size kwargs["hs"] = model_cfg.hidden_size if model_name in ["hf_clip_vit_l"]: kwargs["layers"] = model_cfg.num_hidden_layers kwargs["img_seq_len"] = model_cfg.num_image_embeddings_per_tile kwargs["img_h"] = model_cfg.image_size kwargs["img_w"] = model_cfg.image_size kwargs["patch_dim"] = model_cfg.patch_size kwargs["in_channels"] = model_cfg.num_channels kwargs["class_token_len"] = 1 # TODO: Add directly to HFCLIPVisionConfig elif model_name in ["neva_projection"]: kwargs["projector_type"] = model_cfg.projector_type kwargs["ffn_hs"] = model_cfg.ffn_hidden_size kwargs["inp_s"] = model_cfg.input_size # TODO: Add img_seq_len directly to MultimodalProjectorConfig kwargs["img_seq_len"] = model_name_config_dict["hf_clip_vit_l"].num_image_embeddings_per_tile elif model_name in ["flux"]: kwargs["layers"] = [model_cfg.num_joint_layers, model_cfg.num_single_layers] kwargs["hs"] = model_cfg.hidden_size kwargs["model_channels"] = model_cfg.model_channels kwargs["inp_s"] = model_cfg.context_dim kwargs["in_channels"] = model_cfg.in_channels kwargs["vec_in_dim"] = model_cfg.vec_in_dim else: kwargs["enc_seq_len"] = model_cfg.seq_length kwargs["layers"] = model_cfg.num_layers kwargs["ffn_hs"] = model_cfg.ffn_hidden_size kwargs["attention_heads"] = model_cfg.num_attention_heads kwargs["moe_router_topk"] = model_cfg.moe_router_topk try: query_groups = model_cfg.num_query_groups if query_groups is None: query_groups = model_cfg.num_attention_heads kwargs["query_groups"] = query_groups except: # Multi-modal models use HF model configs which may/may not define num_query_groups pass self.flops_config_dict[model_name] = flops_formulas.FLOPSConfig(**kwargs) self.avg_train_step_time = 0 def eval_model_flops(self): """ Calculate model FLOPs for a given model recursively when model has multiple sub-models """ # Add Multimodal models supported only by MM_FLOPsMeasurementCallback mm_model_flops_map = { **_model_flops_map, "hf_clip_vit_l": flops_formulas.clip_vit_l, "neva_projection": flops_formulas.neva_projection, "flux": flops_formulas.flux, } total_flops = flops_per_gpu = 0 for model_name, flops_cfg in self.flops_config_dict.items(): if model_name not in mm_model_flops_map: logging.info(f"FLOPs measurement supported for {list(mm_model_flops_map.keys())}") raise KeyError( f"Failed to extract valid model name from or missing FLOPs calculations for {model_name}" ) total_flops += mm_model_flops_map[model_name](flops_cfg) num_devices = torch.distributed.get_world_size() if torch.distributed.is_initialized() else 1 flops_per_gpu = total_flops / num_devices return total_flops, flops_per_gpu