File size: 17,658 Bytes
b386992
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
# Copyright (c) 2025, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass
from typing import List, Optional, Union

from nemo.collections.common.parts.perf_metrics_utils import LLM_VOCAB_SIZE_MAP


@dataclass
class FLOPSConfig:
    """Contains the model hparams needed for FLOPS computations.

    Only ``gbs`` is required. Each FLOPs formula in this module reads the subset
    of fields it needs; a field a formula needs but which is left as ``None``
    will generally make that formula fail (e.g. ``None * int`` raises
    ``TypeError``).
    """

    # Batch size multiplied into every formula (transformer() reads it as batch_size).
    gbs: int
    # Sequence length (transformer() reads it as seq_length).
    enc_seq_len: Optional[int] = None
    # Hidden size.
    hs: Optional[int] = None
    # Number of layers. flux() indexes it as a pair: layers[0] joint layers, layers[1] single layers.
    layers: Optional[Union[int, List[int]]] = None
    # FFN (MLP) hidden size.
    ffn_hs: Optional[int] = None
    # Number of attention heads.
    attention_heads: Optional[int] = None
    # Number of experts each token is routed to in MoE layers.
    moe_router_topk: Optional[int] = None
    # Number of K/V groups for grouped-query attention (transformer() defaults it to attention_heads).
    query_groups: Optional[int] = None
    # Vision-encoder fields. clip_vit_l() derives the image sequence length from
    # img_h, img_w, patch_dim and class_token_len when img_seq_len is None.
    img_seq_len: Optional[int] = None
    img_h: Optional[int] = None
    img_w: Optional[int] = None
    in_channels: Optional[int] = None
    patch_dim: Optional[int] = None
    class_token_len: Optional[int] = None
    # NeVA projector kind: "affine" or any value containing "mlp" (see neva_projection()).
    projector_type: Optional[str] = None
    # Input size; read by neva_projection() and flux().
    inp_s: Optional[int] = None
    # NOTE(review): not read by any formula in this file; confirm whether callers still need it.
    model_pattern: Optional[str] = None
    # Vocabulary size for formulas that take it from the config (the gpt3/llama2/llama3/
    # nemotron/mixtral/bert formulas use LLM_VOCAB_SIZE_MAP instead).
    vocab_size: Optional[int] = None
    # FLUX fields.
    model_channels: Optional[int] = None
    vec_in_dim: Optional[int] = None
    # MLA (latent attention) dimensions used by deepseekv3().
    q_lora_rank: Optional[int] = None
    kv_lora_rank: Optional[int] = None
    qk_head_dim: Optional[int] = None
    qk_pos_emb_head_dim: Optional[int] = None
    v_head_dim: Optional[int] = None
    # MoE layer layout: an int n marks layer i as MoE when i % n == 0; a list gives one
    # flag per layer (0 = dense).
    moe_layer_freq: Optional[Union[int, List[int]]] = None
    # Shared-expert and routed-expert FFN hidden sizes.
    moe_shared_expert_intermediate_size: Optional[int] = None
    moe_ffn_hidden_size: Optional[int] = None
    # Number of MTP layers that deepseekv3() counts in addition to `layers`.
    mtp_num_layers: Optional[int] = None
    # Whether attention is causal; transformer() halves the attention-score term when True.
    causal_self_attn: Optional[bool] = None
    # Hybrid-model fields. hybrid_override_pattern holds one character per layer:
    # 'M' Mamba, '-' MLP, '*' attention.
    is_hybrid_model: bool = False
    hybrid_override_pattern: Optional[str] = None
    mamba_state_dim: Optional[int] = None
    mamba_head_dim: Optional[int] = None
    mamba_num_groups: Optional[int] = None
    # Number of Mamba heads; when falsy, _mamba_layer_flops() uses 2 * hs // mamba_head_dim.
    mamba_num_heads: Optional[int] = None


def gpt3(config: FLOPSConfig):
    """Model FLOPs for GPT3 family.

    Returns the training FLOPs for one global batch of ``config.gbs`` sequences of
    length ``config.enc_seq_len``. The vocabulary size is looked up in
    ``LLM_VOCAB_SIZE_MAP["gpt3"]``; ``config.vocab_size`` is not used.
    """
    vocab_size = LLM_VOCAB_SIZE_MAP["gpt3"]
    batch = config.gbs
    seq = config.enc_seq_len
    hidden = config.hs

    # Per-layer terms: one proportional to hs^2 and one proportional to seq^2.
    hidden_sq_term = 24 * batch * seq * hidden * hidden
    seq_sq_term = 4 * batch * seq * seq * hidden
    per_layer = hidden_sq_term + seq_sq_term

    # Output vocabulary projection, counted once for the whole model.
    logits = 6 * batch * seq * hidden * vocab_size

    return per_layer * (3 * config.layers) + logits


def llama2(config: FLOPSConfig):
    """Model FLOPs for llama2 family.

    Computes ``gbs * seq * layers * hs^2`` times a sum of per-component factors
    (attention projections, MLP, attention scores, output vocabulary). The
    vocabulary size comes from ``LLM_VOCAB_SIZE_MAP["llama2"]``.
    """
    vocab_size = LLM_VOCAB_SIZE_MAP["llama2"]
    hidden = config.hs
    seq = config.enc_seq_len
    num_layers = config.layers

    scale = config.gbs * seq * num_layers * hidden * hidden
    # Q/O projections (12) plus grouped K/V projections (12 * groups / heads).
    projections = 12 + (12 * config.query_groups / config.attention_heads)
    # MLP with three weight matrices (6 * 3 = 18), presumably a gated FFN.
    mlp = 18 * config.ffn_hs / hidden
    # Attention-score products.
    scores = 12 * seq / hidden
    # Output vocabulary projection, spread over all layers.
    logits = 6 * vocab_size / (num_layers * hidden)
    return scale * (projections + mlp + scores + logits)


def llama3(config: FLOPSConfig):
    """Model FLOPs for llama3 family.

    Same structure as the llama2 estimate: ``gbs * seq * layers * hs^2`` times
    the sum of projection, MLP, attention-score and vocabulary factors. The
    vocabulary size comes from ``LLM_VOCAB_SIZE_MAP["llama3"]``.
    """
    vocab_size = LLM_VOCAB_SIZE_MAP["llama3"]
    hidden = config.hs
    seq = config.enc_seq_len
    num_layers = config.layers

    scale = config.gbs * seq * num_layers * hidden * hidden
    # Q/O projections (12) plus grouped K/V projections (12 * groups / heads).
    projections = 12 + (12 * config.query_groups / config.attention_heads)
    # MLP with three weight matrices (6 * 3 = 18), presumably a gated FFN.
    mlp = 18 * config.ffn_hs / hidden
    # Attention-score products.
    scores = 12 * seq / hidden
    # Output vocabulary projection, spread over all layers.
    logits = 6 * vocab_size / (num_layers * hidden)
    return scale * (projections + mlp + scores + logits)


def nemotron(config: FLOPSConfig):
    """Model FLOPs for nemotron family.

    ``gbs * seq * layers * hs^2`` times the sum of projection, MLP,
    attention-score and vocabulary factors. Unlike the llama estimates, the
    MLP factor is ``12 * ffn_hs / hs`` (two weight matrices). The vocabulary
    size comes from ``LLM_VOCAB_SIZE_MAP["nemotron"]``.
    """
    vocab_size = LLM_VOCAB_SIZE_MAP["nemotron"]
    hidden = config.hs
    seq = config.enc_seq_len
    num_layers = config.layers

    scale = config.gbs * seq * num_layers * hidden * hidden
    # Q/O projections (12) plus grouped K/V projections (12 * groups / heads).
    projections = 12 + (12 * config.query_groups / config.attention_heads)
    # MLP with two weight matrices (6 * 2 = 12), presumably a non-gated FFN.
    mlp = 12 * config.ffn_hs / hidden
    # Attention-score products.
    scores = 12 * seq / hidden
    # Output vocabulary projection, spread over all layers.
    logits = 6 * vocab_size / (num_layers * hidden)
    return scale * (projections + mlp + scores + logits)


def mixtral(config: FLOPSConfig):
    """Model FLOPs for mixtral family.

    ``gbs * seq * layers * hs^2`` times the sum of projection, MoE-MLP,
    attention-score and vocabulary factors. Each token runs
    ``moe_router_topk`` experts of width ``ffn_hs``. The vocabulary size comes
    from ``LLM_VOCAB_SIZE_MAP["mixtral"]``.
    """
    vocab_size = LLM_VOCAB_SIZE_MAP["mixtral"]
    hidden = config.hs
    seq = config.enc_seq_len
    num_layers = config.layers

    scale = config.gbs * seq * num_layers * hidden * hidden
    # Q/O projections (12) plus grouped K/V projections (12 * groups / heads).
    projections = 12 + (12 * config.query_groups / config.attention_heads)
    # top-k routed experts, each with three weight matrices (6 * 3 = 18).
    experts_mlp = 18 * config.moe_router_topk * config.ffn_hs / hidden
    # Attention-score products.
    scores = 12 * seq / hidden
    # Output vocabulary projection, spread over all layers.
    logits = 6 * vocab_size / (num_layers * hidden)
    return scale * (projections + experts_mlp + scores + logits)


def qwen3(config: FLOPSConfig):
    """Model FLOPs for Qwen3 family.

    Sums attention (GQA, causal), gated-MLP and output-vocabulary FLOPs for one
    global batch. Every GEMM term carries a ``3 * 2`` factor: 3 for forward,
    backward-wgrad and backward-dgrad, 2 for the multiply-add.

    The MLP term covers both dense and MoE variants:

    * ``moe_ffn_hidden_size is None``: the model is treated as dense, with one
      FFN of width ``ffn_hs`` per layer (previously this raised ``TypeError``);
    * otherwise each layer runs ``moe_router_topk`` experts of width
      ``moe_ffn_hidden_size``.

    NOTE(review): the QKV/projection terms size the query projection as ``hs``
    wide, i.e. they assume ``attention_heads * head_dim == hs``. Confirm for
    variants where that does not hold.
    """
    causal_self_attn = True
    seq_len = config.enc_seq_len
    hidden_size = config.hs
    gated_linear_multiplier = 2

    # attention flops for GQA
    attention_flops = (
        3
        * 2
        * config.gbs
        * config.layers
        * seq_len
        * hidden_size
        * hidden_size
        * (
            (config.query_groups / config.attention_heads * 2 + 1)  # QKV gemm
            + (seq_len / hidden_size * 2 * (0.5 if causal_self_attn else 1))  # attention
            + 1  # attention proj gemm
        )
    )

    # mlp flops
    if config.moe_ffn_hidden_size is None:
        # Dense variant: a single FFN of width ffn_hs per layer.
        active_ffn_hidden_size = config.ffn_hs
    else:
        # MoE variant: moe_router_topk routed experts of width moe_ffn_hidden_size per layer.
        active_ffn_hidden_size = config.moe_ffn_hidden_size * config.moe_router_topk
    mlp_flops = (
        3
        * 2
        * config.gbs
        * config.layers
        * seq_len
        * hidden_size
        * (1 + gated_linear_multiplier)
        * active_ffn_hidden_size
    )

    # vocab flops
    vocab_flops = 3 * 2 * config.gbs * seq_len * hidden_size * config.vocab_size

    return attention_flops + mlp_flops + vocab_flops


def bert(config: FLOPSConfig):
    """Model FLOPs for BERT family.

    ``72 * gbs * layers * seq * hs^2`` scaled by
    ``1 + seq / (6 * hs) + V / (12 * hs * layers)`` (attention-score and
    output-vocabulary corrections). ``V`` is taken from
    ``LLM_VOCAB_SIZE_MAP["bert"]``.
    """
    vocab_size = LLM_VOCAB_SIZE_MAP["bert"]
    hidden = config.hs
    seq = config.enc_seq_len
    num_layers = config.layers

    base = 72 * config.gbs * num_layers * seq * hidden * hidden
    correction = 1 + (seq / (6 * hidden)) + (vocab_size / (12 * hidden * num_layers))
    return base * correction


def transformer(config: FLOPSConfig):
    """Calculate FLOPs for a standard Transformer model.
    Note: This does not cover encoder-decoder models.

    The estimate is ``12 * gbs * seq * layers * hs^2`` times the sum of an
    attention term, an MLP term and a logit (vocabulary) term, following the
    Megatron-LM FLOPs accounting.

    MoE handling:

    * ``moe_router_topk`` unset or 0: every layer has a dense MLP of width ``ffn_hs``.
    * Otherwise the MoE layers come from ``moe_layer_freq``:
      * an int ``n`` marks layer ``i`` as MoE when ``i % n == 0``;
      * a list gives one flag per layer (truthy = MoE);
      * when ``moe_layer_freq`` is unset, half of the layers (``layers // 2``)
        are assumed to be MoE, as before.
    * Each MoE layer runs ``moe_router_topk`` routed experts of width
      ``moe_ffn_hidden_size`` (falling back to ``ffn_hs``), plus a shared expert
      of width ``moe_shared_expert_intermediate_size`` when that is set.

    Args:
        config: FLOPs config. ``gbs``, ``hs``, ``enc_seq_len``, ``layers``,
            ``attention_heads``, ``ffn_hs`` and ``vocab_size`` are required.

    Returns:
        Estimated training FLOPs (float) for one global batch.

    Raises:
        ValueError: if ``vocab_size`` is None, or if ``moe_layer_freq`` is a list
            whose length differs from ``layers``.
    """
    # Extract parameters from config
    batch_size = config.gbs
    hidden_size = config.hs
    seq_length = config.enc_seq_len
    num_layers = config.layers
    num_attention_heads = config.attention_heads
    ffn_hidden_size = config.ffn_hs
    vocab_size = config.vocab_size

    if vocab_size is None:
        raise ValueError("vocab_size is required for transformer FLOPs calculation")

    # Handle optional parameters with reasonable defaults
    query_groups = config.query_groups if config.query_groups is not None else num_attention_heads
    causal_self_attn = config.causal_self_attn if config.causal_self_attn is not None else False
    moe_router_topk = config.moe_router_topk if config.moe_router_topk is not None else 0
    kv_channels = hidden_size // num_attention_heads  # Standard dimension per head

    # Calculate query projection size and ratio
    query_projection_size = kv_channels * num_attention_heads
    query_projection_to_hidden_size_ratio = query_projection_size / hidden_size

    # MoE layer layout
    if moe_router_topk == 0:
        num_moe_layers = 0
        num_experts_routed_to = 0
    else:
        moe_layer_freq = config.moe_layer_freq
        if moe_layer_freq is None:
            # No layout information: keep the historical assumption that half the layers are MoE.
            num_moe_layers = num_layers // 2
        elif isinstance(moe_layer_freq, int):
            # Layer i is MoE when i % moe_layer_freq == 0 (same convention as deepseekv3()).
            num_moe_layers = sum(1 for layer_idx in range(num_layers) if layer_idx % moe_layer_freq == 0)
        else:
            if len(moe_layer_freq) != num_layers:
                raise ValueError(
                    f"moe_layer_freq has {len(moe_layer_freq)} entries but the model has {num_layers} layers"
                )
            num_moe_layers = sum(1 for is_moe in moe_layer_freq if is_moe)
        num_experts_routed_to = moe_router_topk
    num_dense_layers = num_layers - num_moe_layers

    moe_ffn_hidden_size = config.moe_ffn_hidden_size if config.moe_ffn_hidden_size is not None else ffn_hidden_size
    shared_expert_ffn_hidden_size = (
        config.moe_shared_expert_intermediate_size if config.moe_shared_expert_intermediate_size is not None else 0
    )

    # Handle SwiGLU vs standard GELU/ReLU
    # Default to standard activation (no SwiGLU)
    gated_linear_multiplier = 1

    # Define the expansion factor as described in the paper
    # 3x: Each GEMM needs forward pass, backward wgrad, and backward dgrad
    # 2x: GEMMs are stacked twice in standard Transformer architectures
    # 2x: A GEMM of m*n with n*k requires 2mnk floating-point operations
    expansion_factor = 3 * 2 * 2

    # Attention: Q/O + grouped K/V projections, plus the attention-score products.
    attention_score_term = seq_length / hidden_size
    if causal_self_attn:
        # Only half of the attention matrix is non-zero and needs to be multiplied with V.
        attention_score_term /= 2
    attention_component = (
        1 + (query_groups / num_attention_heads) + attention_score_term
    ) * query_projection_to_hidden_size_ratio

    # MLP: dense layers plus MoE layers (routed experts and optional shared expert).
    mlp_component = (
        (
            (ffn_hidden_size * num_dense_layers)
            + ((moe_ffn_hidden_size * num_experts_routed_to + shared_expert_ffn_hidden_size) * num_moe_layers)
        )
        * gated_linear_multiplier
        / (num_layers * hidden_size)
    )

    # Logit (output vocabulary projection) component.
    logit_component = vocab_size / (2 * num_layers * hidden_size)

    total_flops = (
        expansion_factor
        * batch_size
        * seq_length
        * num_layers
        * hidden_size
        * hidden_size
        * (attention_component + mlp_component + logit_component)
    )

    return total_flops


def clip_vit_l(config: FLOPSConfig):
    """Model FLOPs for CLIP ViT.

    If ``config.img_seq_len`` is None, the sequence length is derived as the
    number of ``patch_dim x patch_dim`` patches in an ``img_h x img_w`` image
    plus ``class_token_len``. The derived value is used locally only; unlike
    earlier versions, ``config`` is not modified.

    Returns:
        Encoder-layer FLOPs plus a patch-embedding term
        ``2 * gbs * hs * in_channels * img_h * img_w``.
    """
    img_seq_len = config.img_seq_len
    if img_seq_len is None:
        img_seq_len = (config.img_h * config.img_w) / (config.patch_dim * config.patch_dim) + config.class_token_len

    encoder_flops = (
        config.gbs
        * config.layers
        * config.hs
        * config.hs
        * img_seq_len
        * (24 + (4 * img_seq_len / config.hs))
    )
    patch_embedding_flops = 2 * config.gbs * config.hs * config.in_channels * config.img_h * config.img_w
    return encoder_flops + patch_embedding_flops


def neva_projection(config: FLOPSConfig):
    """Model FLOPs for NeVA Projection.

    Any ``projector_type`` containing ``"mlp"`` is costed as
    ``6 * gbs * img_seq_len * ffn_hs * (inp_s + hs)``; ``"affine"`` as
    ``6 * gbs * img_seq_len * inp_s * hs``.

    Raises:
        ValueError: for any other ``projector_type``.
    """
    projector_type = config.projector_type

    if "mlp" in projector_type:
        tokens_x6 = 6 * config.gbs * config.img_seq_len
        return tokens_x6 * config.ffn_hs * (config.inp_s + config.hs)

    if projector_type == "affine":
        tokens_x6 = 6 * config.gbs * config.img_seq_len
        return tokens_x6 * config.inp_s * config.hs

    raise ValueError(
        f"NeVA Projections FLOPs calculator only supports 'mlp', 'mcore_mlp'"
        f" or 'affine' projector_type but found {projector_type}"
    )


def flux(config: FLOPSConfig):
    """Model FLOPs for FLUX.

    Sums three contributions, each scaled by ``6 * gbs``:

    * the joint layers, counted by ``config.layers[0]``;
    * the single layers, counted by ``config.layers[1]``, over a sequence of
      ``model_channels + inp_s``;
    * the embedding and projection layers.
    """
    hs = config.hs
    seq_len = config.model_channels + config.inp_s
    scale = 6 * config.gbs  # common multiplier for most terms
    num_joint_layers = config.layers[0]
    num_single_layers = config.layers[1]

    # Joint layers: hidden-size ops, joint channel/context attention, final projection.
    joint_per_layer = 10 * hs * hs + 2 * hs * seq_len * (1 + hs * 7) + 2 * seq_len * hs
    joint_layer_flops = scale * num_joint_layers * joint_per_layer

    # Single layers: linear Y, modulation, linear ops, attention, feed-forward, modulation.
    single_per_token = 3 + 1 + 4 * hs + (3 * hs + 2 * seq_len) + 5 * hs + 1
    single_layer_flops = scale * num_single_layers * seq_len * hs * single_per_token

    # Embedding and projection layers.
    image_embedding = config.inp_s * config.in_channels * hs
    text_embedding = config.inp_s * hs * config.model_channels
    vector_embedding = config.vec_in_dim * hs + hs * hs
    guidance_timestep_embedding = 2 * (config.model_channels * hs + hs * hs)
    final_projection = (config.inp_s * config.in_channels * hs) / config.gbs
    other_flops = scale * (
        image_embedding + text_embedding + vector_embedding + guidance_timestep_embedding + final_projection
    )

    return joint_layer_flops + single_layer_flops + other_flops


def deepseekv3(config: FLOPSConfig):
    """Model FLOPs for DeepSeek V3.

    Training FLOPs are built per input sequence and then multiplied by
    ``config.gbs``. The per-sequence total has three parts:

    * attention-score FLOPs for the two batched matmuls, with a 0.5 factor
      (presumably for the causal mask);
    * ``6 * params * seq`` for the linear layers: the MLA projections, plus the
      dense or MoE FFN of each layer. The FFN pattern comes from
      ``moe_layer_freq``: an int ``n`` marks layer ``i`` as MoE when
      ``i % n == 0``; a list gives one flag per layer, 0 = dense;
    * the output vocabulary projection.

    When ``mtp_num_layers`` is set, each MTP layer adds one more attention
    block, MLA block, MoE FFN and vocabulary projection, plus a
    ``2*hs -> hs`` projection.
    """
    seq = config.enc_seq_len
    heads = config.attention_heads
    hs = config.hs
    qk_dim = config.qk_head_dim + config.qk_pos_emb_head_dim
    mtp_layers = 0 if config.mtp_num_layers is None else config.mtp_num_layers
    # FFN/vocab MTP contributions are counted like range(mtp_layers): non-positive adds nothing.
    mtp_repeats = max(mtp_layers, 0)

    # --- attention-score flops ---
    bmm1_flops = 0.5 * qk_dim * heads * (seq**2)
    bmm2_flops = 0.5 * config.v_head_dim * heads * (seq**2)
    attention_per_layer = 6 * (bmm1_flops + bmm2_flops)
    per_input_attention_flops = attention_per_layer * config.layers
    per_input_attention_flops += attention_per_layer * mtp_layers

    # --- MLA projection parameters ---
    q_params = hs * config.q_lora_rank + config.q_lora_rank * (qk_dim * heads)
    k_rope_params = hs * config.qk_pos_emb_head_dim
    kv_params = hs * config.kv_lora_rank + config.kv_lora_rank * ((config.qk_head_dim + config.v_head_dim) * heads)
    out_proj_params = config.v_head_dim * heads * hs
    per_layer_mla_params = q_params + k_rope_params + kv_params + out_proj_params
    mla_params = per_layer_mla_params * (config.layers + mtp_layers)

    # --- FFN parameters (gated linear unit: three weight matrices) ---
    glu_matrices = 3
    dense_ffn_params = hs * config.ffn_hs * glu_matrices
    shared_expert_params = hs * config.moe_shared_expert_intermediate_size * glu_matrices
    routed_expert_params = hs * config.moe_ffn_hidden_size * glu_matrices

    if isinstance(config.moe_layer_freq, int):
        layer_is_moe = [1 if (i % config.moe_layer_freq == 0) else 0 for i in range(config.layers)]
    else:
        layer_is_moe = config.moe_layer_freq
    dense_flags = [flag == 0 for flag in layer_is_moe]
    num_dense_layers = sum(dense_flags)
    # MTP layers always use the MoE FFN.
    num_moe_ffns = len(dense_flags) - num_dense_layers + mtp_repeats

    ffn_params = dense_ffn_params * num_dense_layers
    if num_moe_ffns:
        ffn_params += (shared_expert_params + routed_expert_params * config.moe_router_topk) * num_moe_ffns

    per_input_linear_flops = 6 * (mla_params + ffn_params) * seq

    # --- vocabulary flops ---
    vocab_head_flops = 6 * config.vocab_size * hs * seq
    mtp_projection_flops = 6 * hs * 2 * hs * seq
    per_input_vocab_flops = vocab_head_flops * (1 + mtp_repeats) + mtp_projection_flops * mtp_repeats

    return (per_input_attention_flops + per_input_linear_flops + per_input_vocab_flops) * config.gbs


def _nemotronh_mlp_layer_flops(config: FLOPSConfig):
    """Model FLOPs for MLP layer. Assume gated linear unit.

    Counts three ``hs x ffn_hs`` weight matrices at 6 FLOPs per weight per
    token (2 for the multiply-add times 3 for forward plus both backward GEMMs).

    NOTE(review): if the NemotronH MLP is not gated (two matrices), this
    overcounts by 1.5x; confirm against the model definition.
    """
    flops_per_weight_per_token = 6
    glu_weight_matrices = 3
    return (
        flops_per_weight_per_token
        * config.gbs
        * config.enc_seq_len
        * config.hs
        * config.ffn_hs
        * glu_weight_matrices
    )


def _non_mla_attn_layer_flops(config: FLOPSConfig):
    """Model FLOPs for a standard (non-MLA) attention layer.

    ``6 * gbs * seq * hs`` times the sum of:

    * the Q projection (``hs``);
    * the grouped K/V projections (``2 * hs * query_groups / attention_heads``);
    * the attention-score term ``seq / 2 * 2`` (presumably two score products,
      each halved for the causal mask);
    * the output projection (``hs``).
    """
    hidden = config.hs
    seq = config.enc_seq_len

    q_proj = hidden
    kv_proj = config.query_groups / config.attention_heads * hidden * 2
    score_products = seq / 2 * 2
    out_proj = hidden

    token_scale = 6 * config.gbs * seq * hidden
    return token_scale * (q_proj + kv_proj + score_products + out_proj)


def _mamba_layer_flops(config: FLOPSConfig):
    """Model FLOPs for Mamba layer.

    Part of the scan FLOPs is ignored because the chunk size is not known from
    the model config. Three terms are summed:

    * an input projection from ``hs`` to ``2 * d_in + 2 * groups * state + heads``
      features;
    * a ``d_in x state`` term for the scan;
    * an output projection from ``d_in`` back to ``hs``.

    ``d_in = heads * mamba_head_dim``.
    """
    assert config.mamba_state_dim is not None
    assert config.mamba_head_dim is not None

    # A falsy mamba_num_heads (None or 0) falls back to the default expand factor of 2.
    num_heads = config.mamba_num_heads or 2 * config.hs // config.mamba_head_dim
    inner_dim = num_heads * config.mamba_head_dim
    state_dim = config.mamba_state_dim

    in_proj_flops = (
        6
        * config.gbs
        * config.enc_seq_len
        * config.hs
        * (2 * inner_dim + 2 * config.mamba_num_groups * state_dim + num_heads)
    )
    scan_flops = 3 * 2 * config.gbs * config.enc_seq_len * inner_dim * state_dim
    out_proj_flops = 6 * config.gbs * config.enc_seq_len * inner_dim * config.hs
    return in_proj_flops + scan_flops + out_proj_flops


def _hybrid_model_flops(config: FLOPSConfig):
    """Model FLOPs for hybrid model.

    ``config.hybrid_override_pattern`` describes the layer stack with one
    character per layer:

    * ``'M'`` is costed with ``_mamba_layer_flops``;
    * ``'-'`` is costed with ``_nemotronh_mlp_layer_flops``;
    * ``'*'`` is costed with ``_non_mla_attn_layer_flops``;
    * any other character contributes nothing.

    The output vocabulary projection ``6 * gbs * seq * hs * vocab_size`` is
    added once.
    """
    assert config.is_hybrid_model == True  # noqa: E712
    assert config.hybrid_override_pattern is not None

    layer_pattern = config.hybrid_override_pattern
    num_mamba_layers = layer_pattern.count('M')
    num_mlp_layers = layer_pattern.count('-')
    num_attn_layers = layer_pattern.count('*')

    attention_total = num_attn_layers * _non_mla_attn_layer_flops(config)
    mamba_total = num_mamba_layers * _mamba_layer_flops(config)
    mlp_total = num_mlp_layers * _nemotronh_mlp_layer_flops(config)
    vocab_total = 6 * config.gbs * config.enc_seq_len * config.hs * config.vocab_size
    return attention_total + mamba_total + mlp_total + vocab_total


def nemotronh(config: FLOPSConfig):
    """Model FLOPs for NemotronH.

    Delegates to ``_hybrid_model_flops``. ``config.is_hybrid_model`` must be
    True and ``config.hybrid_override_pattern`` must be set.
    """
    hybrid_flops = _hybrid_model_flops(config)
    return hybrid_flops