File size: 9,740 Bytes
bcdf9fa |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 |
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from transformers import PretrainedConfig
# ``config.model_type`` values for which FlopsCounter has a FLOPs estimator;
# any other model type triggers a warning at construction time.
VALID_CONFIG_TYPE = {
    "llama",
    "qwen2",
    "qwen2_vl",
    "qwen2_5_vl",
    "qwen3",
    "qwen3_moe",
    "deepseek_v3",
}
def get_device_flops(unit="T"):
    """Return the peak (promised) FLOPS of the current device, scaled to ``unit``.

    The device is identified by substring-matching the name reported by
    ``torch.cuda.get_device_name()`` against a small table of known parts.

    Args:
        unit (str): One of "B", "K", "M", "G", "T", "P". The raw FLOPS value
            is divided by 1000 for every step from "B" up to ``unit``.

    Returns:
        float: Peak FLOPS expressed in ``unit``. ``float("inf")`` is returned
        when the device name is not in the table or when no CUDA device is
        available, so that a ratio computed against it evaluates to zero.

    Raises:
        ValueError: If ``unit`` is not one of the supported unit letters.
    """
    units = ["B", "K", "M", "G", "T", "P"]
    if unit not in units:
        # Previously an unknown unit silently divided by 1000 ** len(units).
        raise ValueError(f"Unsupported unit {unit!r}; expected one of {units}.")
    scale = 1000 ** units.index(unit)

    # (name fragments, peak FLOPS). Order matters: the first matching entry wins.
    known_devices = (
        (("MI300X",), 1336e12),
        (("H100", "H800"), 989e12),
        (("A100", "A800"), 312e12),
        (("L40",), 181.05e12),
        (("L20",), 119.5e12),
        (("H20",), 148e12),
        (("910B",), 354e12),
    )

    # Querying the device name raises when no CUDA device is present; treat
    # that case like an unknown device instead of crashing the caller.
    device_name = torch.cuda.get_device_name() if torch.cuda.is_available() else ""

    flops = float("inf")  # INF flops for unknown gpu type
    for fragments, peak_flops in known_devices:
        if any(fragment in device_name for fragment in fragments):
            flops = peak_flops
            break

    return flops / scale
class FlopsCounter:
    """
    Estimate the FLOPS achieved by a training step so that MFU can be derived.

    The estimator is selected from ``config.model_type``. Model types without
    an estimator always report 0 achieved FLOPS.

    Example:
        flops_counter = FlopsCounter(config)
        flops_achieved, flops_promised = flops_counter.estimate_flops(tokens_list, delta_time)
    """

    def __init__(self, config: "PretrainedConfig"):
        """Store ``config`` and build the model_type -> estimator dispatch table.

        Args:
            config: Model configuration; ``config.model_type`` selects the
                estimator, and the estimator reads its size fields later.
        """
        if config.model_type not in VALID_CONFIG_TYPE:
            print(f"Only support config type of {VALID_CONFIG_TYPE}, but got {config.model_type}. MFU will always be zero.")

        self.estimate_func = {
            "qwen2": self._estimate_qwen2_flops,
            "llama": self._estimate_qwen2_flops,
            "qwen2_vl": self._estimate_qwen2_flops,
            "qwen2_5_vl": self._estimate_qwen2_flops,
            "qwen3": self._estimate_qwen2_flops,
            "qwen3_moe": self._estimate_qwen3_moe_flops,
            "deepseek_v3": self._estimate_deepseek_v3_flops,
        }
        self.config = config

    @staticmethod
    def _to_tflops(flops_all_token, delta_time):
        """Convert a FLOP count spent over ``delta_time`` seconds into TFLOPS."""
        return flops_all_token * (1.0 / delta_time) / 1e12

    @staticmethod
    def _attn_core_flops(batch_seqlens, head_dim, num_heads, num_layers):
        """Return the sequence-length-quadratic attention FLOPs (fwd & bwd) over all layers."""
        seqlen_square_sum = sum(seqlen * seqlen for seqlen in batch_seqlens)
        return 12 * seqlen_square_sum * head_dim * num_heads * num_layers

    def _resolve_head_dim(self):
        """Return ``config.head_dim``, or ``hidden_size // num_attention_heads`` if missing or None."""
        head_dim = getattr(self.config, "head_dim", None)
        if head_dim is None:
            # The attribute may exist but be unset; fall back instead of
            # multiplying by None further down.
            head_dim = self.config.hidden_size // self.config.num_attention_heads
        return head_dim

    def _gqa_attn_linear_params(self, head_dim):
        """Return the per-layer parameter count of the attention linear layers."""
        cfg = self.config
        q_size = cfg.num_attention_heads * head_dim
        k_size = cfg.num_key_value_heads * head_dim
        v_size = cfg.num_key_value_heads * head_dim
        # query + key + value, plus a fourth term the size of the query
        # projection (the attention output projection).
        return cfg.hidden_size * (q_size + k_size + v_size + cfg.num_attention_heads * head_dim)

    def _estimate_unknown_flops(self, tokens_sum, batch_seqlens, delta_time):
        """Fallback estimator for unsupported model types; always returns 0."""
        return 0

    def _estimate_qwen2_flops(self, tokens_sum, batch_seqlens, delta_time):
        """Estimate achieved TFLOPS for dense Qwen2/LLaMA-style models.

        Args:
            tokens_sum: Total number of tokens in the batch.
            batch_seqlens: Per-sequence token counts.
            delta_time: Elapsed time in seconds; must be non-zero.

        Returns:
            Estimated forward + backward throughput in TFLOPS.
        """
        cfg = self.config
        head_dim = self._resolve_head_dim()

        # non-attn per layer param
        # Qwen2/LLaMA use SwiGLU: gate, up and down linear layers in the mlp
        mlp_N = cfg.hidden_size * cfg.intermediate_size * 3
        attn_linear_N = self._gqa_attn_linear_params(head_dim)
        emd_and_lm_head_N = cfg.vocab_size * cfg.hidden_size * 2

        # non-attn all_layer param
        dense_N = (mlp_N + attn_linear_N) * cfg.num_hidden_layers + emd_and_lm_head_N
        # non-attn all_layer & all_token fwd & bwd flops
        dense_N_flops = 6 * dense_N * tokens_sum

        # attn all_layer & all_token fwd & bwd flops
        attn_qkv_flops = self._attn_core_flops(batch_seqlens, head_dim, cfg.num_attention_heads, cfg.num_hidden_layers)

        # all_layer & all_token fwd & bwd flops
        flops_all_token = dense_N_flops + attn_qkv_flops
        return self._to_tflops(flops_all_token, delta_time)

    def _estimate_deepseek_v3_flops(self, tokens_sum, batch_seqlens, delta_time):
        """Estimate achieved TFLOPS for DeepSeek-V3 (MLA attention, MoE mlp).

        Args:
            tokens_sum: Total number of tokens in the batch.
            batch_seqlens: Per-sequence token counts.
            delta_time: Elapsed time in seconds; must be non-zero.

        Returns:
            Estimated forward + backward throughput in TFLOPS.
        """
        cfg = self.config
        hidden_size = cfg.hidden_size
        num_hidden_layers = cfg.num_hidden_layers
        first_k_dense_replace = cfg.first_k_dense_replace
        num_query_heads = cfg.num_attention_heads

        # non-attn per layer param
        moe_gate_N = hidden_size * cfg.n_routed_experts
        # moe has fc1_1, fc1_2 and fc2 using SwiGLU in ExpertMlp layer & shared experts;
        # only the top-k routed experts plus the shared experts are counted per token
        moe_expertmlp_N = hidden_size * cfg.moe_intermediate_size * (cfg.num_experts_per_tok + cfg.n_shared_experts) * 3

        # MLA attn
        q_head_dim = cfg.qk_nope_head_dim + cfg.qk_rope_head_dim
        if cfg.q_lora_rank is None:
            # full-rank query projection
            attn_linear_N = hidden_size * num_query_heads * q_head_dim
        else:
            # low-rank query projection: down-projection followed by up-projection
            attn_linear_N = hidden_size * cfg.q_lora_rank
            attn_linear_N += num_query_heads * q_head_dim * cfg.q_lora_rank
        attn_linear_N += hidden_size * (cfg.kv_lora_rank + cfg.qk_rope_head_dim)
        attn_linear_N += num_query_heads * (q_head_dim - cfg.qk_rope_head_dim + cfg.v_head_dim) * cfg.kv_lora_rank
        attn_linear_N += num_query_heads * cfg.v_head_dim * hidden_size
        emd_and_lm_head_N = cfg.vocab_size * hidden_size * 2

        # non-attn all_layer param: the first `first_k_dense_replace` layers use a
        # dense mlp of width `intermediate_size`, the remaining layers use MoE
        moe_layers_N = (moe_gate_N + moe_expertmlp_N + attn_linear_N) * (num_hidden_layers - first_k_dense_replace)
        dense_layers_N = (hidden_size * cfg.intermediate_size * 3 + attn_linear_N) * first_k_dense_replace
        moe_N = moe_layers_N + dense_layers_N + emd_and_lm_head_N
        # non-attn all_layer & all_token fwd & bwd flops
        dense_N_flops = 6 * moe_N * tokens_sum

        # attn all_layer & all_token fwd & bwd flops
        attn_qkv_flops = self._attn_core_flops(batch_seqlens, q_head_dim, num_query_heads, num_hidden_layers)

        # all_layer & all_token fwd & bwd flops
        flops_all_token = dense_N_flops + attn_qkv_flops
        return self._to_tflops(flops_all_token, delta_time)

    def _estimate_qwen3_moe_flops(self, tokens_sum, batch_seqlens, delta_time):
        """Estimate achieved TFLOPS for Qwen3-MoE models.

        Args:
            tokens_sum: Total number of tokens in the batch.
            batch_seqlens: Per-sequence token counts.
            delta_time: Elapsed time in seconds; must be non-zero.

        Returns:
            Estimated forward + backward throughput in TFLOPS.
        """
        cfg = self.config
        head_dim = self._resolve_head_dim()

        # non-attn per layer param
        # top-k activated experts (3 linear layers each) + router gate
        moe_mlp_N = cfg.hidden_size * cfg.num_experts_per_tok * cfg.moe_intermediate_size * 3 + cfg.hidden_size * cfg.num_experts
        attn_linear_N = self._gqa_attn_linear_params(head_dim)
        emd_and_lm_head_N = cfg.vocab_size * cfg.hidden_size * 2

        # non-attn all_layer param
        dense_N = (moe_mlp_N + attn_linear_N) * cfg.num_hidden_layers + emd_and_lm_head_N
        # non-attn all_layer & all_token fwd & bwd flops
        dense_N_flops = 6 * dense_N * tokens_sum

        # attn all_layer & all_token fwd & bwd flops
        attn_qkv_flops = self._attn_core_flops(batch_seqlens, head_dim, cfg.num_attention_heads, cfg.num_hidden_layers)

        # all_layer & all_token fwd & bwd flops
        flops_all_token = dense_N_flops + attn_qkv_flops
        return self._to_tflops(flops_all_token, delta_time)

    def estimate_flops(self, batch_seqlens, delta_time):
        """
        Estimate the achieved FLOPS for one batch and report the device's promised FLOPS.

        Args:
            batch_seqlens (List[int]): Number of valid tokens of each sequence in the current batch.
            delta_time (float): The time taken to process the batch, in seconds. Must be non-zero.

        Returns:
            estimated_flops (float): Estimated achieved throughput in TFLOPS
                (0 for unsupported model types).
            promised_flops (float): Peak throughput of the current device as
                returned by ``get_device_flops()`` with its default unit.
        """
        tokens_sum = sum(batch_seqlens)
        func = self.estimate_func.get(self.config.model_type, self._estimate_unknown_flops)
        estimated_flops = func(tokens_sum, batch_seqlens, delta_time)
        promised_flops = get_device_flops()
        return estimated_flops, promised_flops
|