File size: 5,150 Bytes
7155cf2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 |
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING, List, Optional, Tuple

import torch

if TYPE_CHECKING:
    from transformers.models.llama.configuration_llama import LlamaConfig
# Values of ``config.model_type`` for which FlopsCounter has a FLOPs estimator;
# any other model type makes the counter report zero MFU.
# NOTE: the misspelled name is kept as-is since other code may import it.
VALID_MODLE_TYPE = {
    "llama",
    "qwen2",
    "qwen2_vl",
    "qwen2_5_vl",
}
def get_device_flops(unit: str = "T", device_name: Optional[str] = None) -> float:
    """Return the peak throughput of the current accelerator, scaled to ``unit``.

    The table values appear to be vendor peak dense BF16/FP16 tensor-core
    figures (e.g. 989 TFLOPS for H100, 312 TFLOPS for A100).

    Args:
        unit: Magnitude prefix of the result: ``"B"`` (raw FLOPS), ``"K"``,
            ``"M"``, ``"G"``, ``"T"`` or ``"P"`` (each step is a factor of 1000).
            Defaults to ``"T"``.
        device_name: Name to look up. When ``None`` (the default) it is taken
            from ``torch.cuda.get_device_name()``; passing it explicitly allows
            the lookup without touching the device.

    Returns:
        Peak FLOPS of the device expressed in ``unit``, or ``float("inf")`` for
        an unrecognized device, which drives any achieved/peak ratio to 0.

    Raises:
        ValueError: If ``unit`` is not one of the supported prefixes.
    """
    units = ["B", "K", "M", "G", "T", "P"]
    if unit not in units:
        # Previously an unknown unit silently divided by 1000**6.
        raise ValueError(f"Unsupported unit {unit!r}; expected one of {units}.")

    if device_name is None:
        device_name = torch.cuda.get_device_name()

    # Ordered (name substrings, peak FLOPS) table; the first entry with a
    # substring occurring in the device name wins. "H200" sits in the first
    # tier so it is matched before the "H20" entry would claim it.
    known_devices = (
        (("H100", "H800", "H200"), 989e12),
        (("A100", "A800"), 312e12),
        (("L40",), 181.05e12),
        (("L20",), 119.5e12),
        (("H20",), 148e12),
        (("910B",), 354e12),
    )
    flops = float("inf")  # unknown device type
    for markers, peak in known_devices:
        if any(marker in device_name for marker in markers):
            flops = peak
            break

    # Scale down by 1000 per prefix step ("B" -> 1, "K" -> 1e3, ..., "P" -> 1e15).
    return flops / 1000 ** units.index(unit)
class FlopsCounter:
    """
    Estimate achieved FLOPS during a training loop, for MFU reporting.

    Example:
        flops_counter = FlopsCounter(config)
        flops_achieved, flops_promised = flops_counter.estimate_flops(tokens_list, delta_time)
    """

    def __init__(self, config: "LlamaConfig"):
        if config.model_type not in VALID_MODLE_TYPE:
            print(f"Only support {VALID_MODLE_TYPE}, but got {config.model_type}. MFU will always be zero.")

        # All supported model types currently share the Llama-style estimator;
        # unsupported types fall back to _estimate_unknown_flops in estimate_flops.
        self.estimate_func = {
            "llama": self._estimate_llama_flops,
            "qwen2": self._estimate_llama_flops,
            "qwen2_vl": self._estimate_llama_flops,
            "qwen2_5_vl": self._estimate_llama_flops,
        }
        self.config = config

    def _estimate_unknown_flops(self, tokens_sum: int, batch_seqlens: List[int], delta_time: float) -> float:
        """Fallback estimator for unsupported model types: always reports 0."""
        return 0

    def _estimate_llama_flops(self, tokens_sum: int, batch_seqlens: List[int], delta_time: float) -> float:
        """Estimate achieved TFLOPS (1e12 FLOPS/s) of a Llama-style decoder step.

        Dense weights are charged 6 FLOPs per parameter per token (forward +
        backward); attention score/value matmuls are added per sequence since
        they scale with seqlen**2.

        Args:
            tokens_sum: Total number of valid tokens in the batch.
            batch_seqlens: Valid-token count of each sequence in the batch.
            delta_time: Wall time for the step, in seconds (must be non-zero).

        Returns:
            Estimated achieved throughput in TFLOPS.
        """
        config = self.config
        hidden_size = config.hidden_size
        vocab_size = config.vocab_size
        num_hidden_layers = config.num_hidden_layers
        num_key_value_heads = config.num_key_value_heads
        num_attention_heads = config.num_attention_heads
        intermediate_size = config.intermediate_size

        # Some configs set head_dim explicitly and it need not equal
        # hidden_size // num_attention_heads; fall back to the derived value.
        head_dim = getattr(config, "head_dim", None) or hidden_size // num_attention_heads

        q_size = num_attention_heads * head_dim
        kv_size = num_key_value_heads * head_dim

        # Per-layer dense parameters: three MLP projections (gate, up, down)
        # and the q/k/v/o attention projections (o maps q_size back to hidden).
        mlp_params = hidden_size * intermediate_size * 3
        attn_linear_params = hidden_size * (q_size + kv_size + kv_size + q_size)
        # Embedding and LM head are counted as two vocab x hidden matrices,
        # regardless of whether the model ties them.
        embed_and_lm_head_params = vocab_size * hidden_size * 2
        dense_params = (mlp_params + attn_linear_params) * num_hidden_layers + embed_and_lm_head_params

        # 2 FLOPs per multiply-add forward, 4 backward -> 6 per param per token.
        dense_flops = 6 * dense_params * tokens_sum

        # QK^T and softmax(.)V: 2 matmuls x 2 FLOPs x seqlen^2 x head_dim per
        # head forward, tripled for forward + backward -> factor 12.
        seqlen_square_sum = sum(seqlen * seqlen for seqlen in batch_seqlens)
        attn_flops = 12 * seqlen_square_sum * head_dim * num_attention_heads * num_hidden_layers

        flops_all_token = dense_flops + attn_flops
        return flops_all_token * (1.0 / delta_time) / 1e12

    def estimate_flops(self, batch_seqlens: List[int], delta_time: float) -> Tuple[float, float]:
        """
        Estimate the achieved FLOPS of one step and the device's peak FLOPS.

        Args:
            batch_seqlens (List[int]): Number of valid tokens in each sequence of the current batch.
            delta_time (float): The time taken to process the batch, in seconds.

        Returns:
            estimated_flops (float): Achieved throughput in TFLOPS (0 for unsupported model types).
            promised_flops (float): Peak throughput of the current device in TFLOPS
                (``inf`` for an unrecognized device).
        """
        tokens_sum = sum(batch_seqlens)
        func = self.estimate_func.get(self.config.model_type, self._estimate_unknown_flops)
        estimated_flops = func(tokens_sum, batch_seqlens, delta_time)
        promised_flops = get_device_flops()
        return estimated_flops, promised_flops
|