File size: 12,905 Bytes
b386992
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
# Copyright (c) 2025, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List, Tuple

import lightning.pytorch as pl
import numpy as np
import torch
from lightning.pytorch.callbacks import Callback

from nemo.collections.llm.gpt.model.base import GPTConfig
from nemo.lightning.pytorch.callbacks import PEFT
from nemo.utils import flops_formulas, logging
from nemo.utils.hyena_flops_formulas import hyena


__all__ = [
    "FLOPsMeasurementCallback",
    "MM_FLOPsMeasurementCallback",
]

# Registry of FLOPs formulas: model-family key -> function that takes a
# ``flops_formulas.FLOPSConfig`` and returns the total model FLOPs.
# Insertion order is significant: FLOPsMeasurementCallback.eval_model_flops
# selects the FIRST key that occurs as a substring of the lower-cased model name.
_model_flops_map = dict(
    gpt3=flops_formulas.gpt3,
    llama2=flops_formulas.llama2,
    llama3=flops_formulas.llama3,
    llama4=flops_formulas.llama3,  # TODO: add llama4 flops formulas
    nemotron3=flops_formulas.nemotron,
    nemotron4=flops_formulas.nemotron,
    mixtral=flops_formulas.mixtral,
    bert=flops_formulas.bert,
    hyena=hyena,
    deepseekv3=flops_formulas.deepseekv3,
    transformer=flops_formulas.transformer,
    qwen3=flops_formulas.qwen3,
    nemotronh=flops_formulas.nemotronh,
)


class FLOPsMeasurementCallback(Callback):
    """
    Calculate and log FLOPs per second after every ``trainer.log_every_n_steps`` steps.

    Step times are read from ``trainer.progress_bar_metrics['train_step_timing in s']``,
    so a TimingCallback must be registered alongside this callback.

    Args:
        model_config (GPTConfig): Model parameters.
        data_config (pl.LightningDataModule): Data module being used in the experiment.
        model_name (str): Name of the model being run. It is lower-cased, and the first key of
            ``_model_flops_map`` contained in it selects the FLOPs formula. Supported keys:
            gpt3, llama2, llama3, llama4, nemotron3, nemotron4, mixtral, bert, hyena,
            deepseekv3, transformer, qwen3, nemotronh.
    """

    higher_is_better = True

    # Set to True on the instance once the missing-timing warning has been emitted,
    # so the warning is not repeated on every training batch. Kept as a class attribute
    # so subclasses that do not call this __init__ still see a default.
    _missing_timing_warned = False

    def __init__(
        self,
        model_config: GPTConfig,
        data_config: pl.LightningDataModule,
        model_name: str,
    ):
        super().__init__()
        self.model_cfg = model_config
        self.data_cfg = data_config

        # Lower-cased so it can be substring-matched against the lower-case formula keys.
        self.model = model_name.lower() if model_name is not None else model_name

        gbs = self.data_cfg.global_batch_size
        enc_seq_len = self.model_cfg.seq_length
        hs = self.model_cfg.hidden_size
        layers = self.model_cfg.num_layers
        ffn_hs = self.model_cfg.ffn_hidden_size
        attention_heads = self.model_cfg.num_attention_heads
        moe_router_topk = self.model_cfg.moe_router_topk
        model_pattern = getattr(self.model_cfg, "hybrid_override_pattern", None)

        # Not every data module exposes a tokenizer, and one that does may still hold None.
        tokenizer = getattr(self.data_cfg, "tokenizer", None)
        vocab_size = tokenizer.vocab_size if tokenizer is not None else None

        # Handles both: 1. attribute present with value None; 2. attribute absent.
        # Either way fall back to standard multi-head attention (one group per head).
        query_groups = getattr(self.model_cfg, "num_query_groups", None)
        if query_groups is None:
            query_groups = attention_heads

        config_kwargs = {
            "gbs": gbs,
            "enc_seq_len": enc_seq_len,
            "hs": hs,
            "layers": layers,
            "ffn_hs": ffn_hs,
            "attention_heads": attention_heads,
            "moe_router_topk": moe_router_topk,
            "query_groups": query_groups,
            "vocab_size": vocab_size,
            "model_pattern": model_pattern,
        }

        from megatron.core.transformer.transformer_config import MLATransformerConfig

        # Multi-latent-attention configs carry extra head/LoRA-rank dimensions.
        if isinstance(self.model_cfg, MLATransformerConfig):
            config_kwargs["qk_head_dim"] = self.model_cfg.qk_head_dim
            config_kwargs["qk_pos_emb_head_dim"] = self.model_cfg.qk_pos_emb_head_dim
            config_kwargs["v_head_dim"] = self.model_cfg.v_head_dim
            config_kwargs["q_lora_rank"] = self.model_cfg.q_lora_rank
            config_kwargs["kv_lora_rank"] = self.model_cfg.kv_lora_rank
        config_kwargs["moe_layer_freq"] = self.model_cfg.moe_layer_freq
        config_kwargs["moe_shared_expert_intermediate_size"] = self.model_cfg.moe_shared_expert_intermediate_size
        config_kwargs["moe_ffn_hidden_size"] = self.model_cfg.moe_ffn_hidden_size
        config_kwargs["mtp_num_layers"] = self.model_cfg.mtp_num_layers

        if getattr(self.model_cfg, "is_hybrid_model", False):
            config_kwargs['is_hybrid_model'] = True
            config_kwargs['hybrid_override_pattern'] = self.model_cfg.hybrid_override_pattern
            config_kwargs['mamba_state_dim'] = self.model_cfg.mamba_state_dim
            config_kwargs['mamba_head_dim'] = self.model_cfg.mamba_head_dim
            config_kwargs['mamba_num_groups'] = self.model_cfg.mamba_num_groups
            config_kwargs['mamba_num_heads'] = self.model_cfg.mamba_num_heads

        self.flops_config = flops_formulas.FLOPSConfig(**config_kwargs)

        # Running SUM of step times since the last log (divided by log_every_n_steps when logged).
        self.avg_train_step_time = 0

    def on_train_start(self, trainer, pl_module):
        """
        PyTorch Lightning callback hook. Ensures that user is not using PEFT
        as FLOPS callback does not support it.

        Raises:
            NotImplementedError: if any registered callback is a PEFT callback.
        """
        for callback in trainer.callbacks:
            if isinstance(callback, PEFT):
                raise NotImplementedError("FLOPs measurement not supported for finetuning jobs")

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx: int):
        """
        PyTorch Lightning callback hook to calculate TFLOPs per sec per GPU after training.

        Accumulates the step time reported by the timing callback and, every
        ``trainer.log_every_n_steps`` steps, logs ``TFLOPS_per_GPU`` and ``TFLOPS``.
        """
        try:
            self.avg_train_step_time += trainer.progress_bar_metrics['train_step_timing in s']
        except KeyError:
            # Warn once instead of flooding stdout on every batch.
            if not self._missing_timing_warned:
                logging.warning(
                    "'train_step_timing in s' not found. Make sure to use TimingCallback with FLOPsMeasurementCallback."
                )
                self._missing_timing_warned = True
        n = trainer.strategy.current_epoch_step
        if n % trainer.log_every_n_steps == 0:
            # skip calculation if we haven't accumulated any timing data
            if self.avg_train_step_time == 0:
                return
            train_step_time = self.avg_train_step_time / trainer.log_every_n_steps
            tflops_per_gpu, flops = self.eval_tflops_per_sec_per_gpu(train_step_time)
            self.avg_train_step_time = 0
            pl_module.log(
                "TFLOPS_per_GPU",
                tflops_per_gpu,
                on_step=True,
                on_epoch=False,
                batch_size=1,
                prog_bar=True,
            )

            tflops = flops / (1e12 * train_step_time)
            pl_module.log(
                "TFLOPS",
                tflops,
            )

    def eval_tflops_per_sec_per_gpu(self, train_step_time: List | float | int) -> Tuple[float, float]:
        """
        Compute achieved TFLOPs per second per GPU.

        Args:
            train_step_time (List | float | int): Train step time(s) in seconds.
                Step time is less stable for initial steps (~10 steps), so when a list
                is given only its second half is averaged. Use average step time over
                several steps for higher accuracy.
        Returns:
            Tuple[float, float]: (model TFLOPs per sec per GPU, total model FLOPs).
        Raises:
            ValueError: if ``train_step_time`` is an empty list.
        """
        if not isinstance(train_step_time, list):
            train_step_time = [train_step_time]
        if len(train_step_time) == 0:
            raise ValueError("train_step_time must contain at least one step time")

        total_flops, flops_per_gpu = self.eval_model_flops()

        # Discard the first half of the measurements (warm-up steps) and average the rest.
        step_time_arr = np.array(train_step_time)
        mean_step_time = np.mean(step_time_arr[len(step_time_arr) // 2 :])

        flops_per_sec_per_gpu = flops_per_gpu / (1e12 * mean_step_time)

        return flops_per_sec_per_gpu, total_flops

    def eval_model_flops(self) -> Tuple[float, float]:
        """
        Calculate model FLOPs for the configured model.

        ``self.model`` is resolved to the first ``_model_flops_map`` key that is a
        substring of it, and the resolved key is stored back into ``self.model``.

        Returns:
            Tuple[float, float]: (total FLOPs, FLOPs per GPU), where FLOPs per GPU is the
            total divided by the torch.distributed world size (1 if not initialized).
        Raises:
            KeyError: if no supported model key matches ``self.model``.
        """
        if self.model is not None:
            model_matches = [model for model in _model_flops_map if model in self.model]
            self.model = model_matches[0] if len(model_matches) > 0 else self.model
        if self.model not in _model_flops_map:
            logging.info(f"FLOPs measurement supported for {list(_model_flops_map.keys())}")
            raise KeyError(f"Failed to extract valid model name from or missing FLOPs calculations for {self.model}")

        total_flops = _model_flops_map[self.model](self.flops_config)
        num_devices = torch.distributed.get_world_size() if torch.distributed.is_initialized() else 1
        flops_per_gpu = total_flops / num_devices

        return total_flops, flops_per_gpu


class MM_FLOPsMeasurementCallback(FLOPsMeasurementCallback):
    """
    Calculate and log FLOPs per second after every ``trainer.log_every_n_steps`` steps for multi-modal models.

    Each entry of ``model_name_config_dict`` is turned into its own ``FLOPSConfig``, and the
    FLOPs of all sub-models are summed. Sub-model names must match a supported key exactly:
    hf_clip_vit_l, neva_projection, flux, or any key of ``_model_flops_map``
    (gpt3, llama2, llama3, nemotron3/4, mixtral, bert, hyena, ...).

    Args:
        model_name_config_dict (dict):
            Dictionary containing all the individual model configs that make up the multi-modal model.
            A "neva_projection" entry also requires an "hf_clip_vit_l" entry, because the projector's
            image sequence length is read from the CLIP config.
        data_config (pl.LightningDataModule): Data module being used in the experiment.
    """

    higher_is_better = True

    def __init__(
        self,
        model_name_config_dict: dict,
        data_config: pl.LightningDataModule,
    ):
        # NOTE: the parent __init__ is intentionally not called; it expects a single GPT config.
        self.data_cfg = data_config
        self.flops_config_dict = dict()

        for model_name, model_cfg in model_name_config_dict.items():
            kwargs = dict()
            kwargs["gbs"] = self.data_cfg.global_batch_size
            kwargs["hs"] = model_cfg.hidden_size
            if model_name in ["hf_clip_vit_l"]:
                kwargs["layers"] = model_cfg.num_hidden_layers
                kwargs["img_seq_len"] = model_cfg.num_image_embeddings_per_tile
                kwargs["img_h"] = model_cfg.image_size
                kwargs["img_w"] = model_cfg.image_size
                kwargs["patch_dim"] = model_cfg.patch_size
                kwargs["in_channels"] = model_cfg.num_channels
                kwargs["class_token_len"] = 1  # TODO: Add directly to HFCLIPVisionConfig
            elif model_name in ["neva_projection"]:
                kwargs["projector_type"] = model_cfg.projector_type
                kwargs["ffn_hs"] = model_cfg.ffn_hidden_size
                kwargs["inp_s"] = model_cfg.input_size
                # TODO: Add img_seq_len directly to MultimodalProjectorConfig
                clip_cfg = model_name_config_dict.get("hf_clip_vit_l")
                if clip_cfg is None:
                    raise KeyError(
                        "'neva_projection' FLOPs measurement requires an 'hf_clip_vit_l' entry in model_name_config_dict"
                    )
                kwargs["img_seq_len"] = clip_cfg.num_image_embeddings_per_tile
            elif model_name in ["flux"]:
                # Flux has two stacks: joint (double-stream) and single-stream layers.
                kwargs["layers"] = [model_cfg.num_joint_layers, model_cfg.num_single_layers]
                kwargs["model_channels"] = model_cfg.model_channels
                kwargs["inp_s"] = model_cfg.context_dim
                kwargs["in_channels"] = model_cfg.in_channels
                kwargs["vec_in_dim"] = model_cfg.vec_in_dim
            else:
                kwargs["enc_seq_len"] = model_cfg.seq_length
                kwargs["layers"] = model_cfg.num_layers
                kwargs["ffn_hs"] = model_cfg.ffn_hidden_size
                kwargs["attention_heads"] = model_cfg.num_attention_heads
                kwargs["moe_router_topk"] = model_cfg.moe_router_topk

            try:
                query_groups = model_cfg.num_query_groups
                if query_groups is None:
                    query_groups = model_cfg.num_attention_heads
                kwargs["query_groups"] = query_groups
            except AttributeError:
                # Multi-modal models use HF model configs which may/may not define
                # num_query_groups (or num_attention_heads); omit query_groups then.
                pass

            self.flops_config_dict[model_name] = flops_formulas.FLOPSConfig(**kwargs)

        # Running SUM of step times since the last log (used by the inherited on_train_batch_end).
        self.avg_train_step_time = 0

    def eval_model_flops(self) -> Tuple[float, float]:
        """
        Calculate model FLOPs for a given model recursively when model has multiple sub-models.

        Returns:
            Tuple[float, float]: (total FLOPs summed over all sub-models, FLOPs per GPU), where
            FLOPs per GPU divides the total by the torch.distributed world size (1 if not initialized).
        Raises:
            KeyError: if a sub-model name has no FLOPs formula.
        """

        # Add Multimodal models supported only by MM_FLOPsMeasurementCallback
        mm_model_flops_map = {
            **_model_flops_map,
            "hf_clip_vit_l": flops_formulas.clip_vit_l,
            "neva_projection": flops_formulas.neva_projection,
            "flux": flops_formulas.flux,
        }

        total_flops = 0
        for model_name, flops_cfg in self.flops_config_dict.items():
            if model_name not in mm_model_flops_map:
                logging.info(f"FLOPs measurement supported for {list(mm_model_flops_map.keys())}")
                raise KeyError(
                    f"Failed to extract valid model name from or missing FLOPs calculations for {model_name}"
                )
            total_flops += mm_model_flops_map[model_name](flops_cfg)
        num_devices = torch.distributed.get_world_size() if torch.distributed.is_initialized() else 1
        flops_per_gpu = total_flops / num_devices

        return total_flops, flops_per_gpu