|
|
import copy |
|
|
import re |
|
|
from dataclasses import dataclass, field |
|
|
from typing import Dict, Optional, Tuple, Union
|
|
|
|
|
import pynvml |
|
|
import torch |
|
|
from cuda import cudart |
|
|
|
|
|
from tensorrt_llm._utils import DictConversion |
|
|
from tensorrt_llm.logger import logger |
|
|
from tensorrt_llm.profiler import PyNVMLContext, _device_get_memory_info_fn |
|
|
|
|
|
|
|
|
@dataclass
class MathThroughput(DictConversion):
    """Peak math throughput per data type, expressed in TFLOPS."""

    # 0 means "unknown / not supported" for that data type.
    int4: int = 0
    int8: int = 0
    fp8: int = 0
    float16: int = 0
    bfloat16: int = 0
    float32: int = 0

    @staticmethod
    def to_tflops(
        ipc_per_sm: "MathThroughput",
        sm_count: int,
        clock_mhz: int,
    ) -> "MathThroughput":
        """Convert per-SM ops-per-clock figures into whole-device TFLOPS.

        TFLOPS = ipc * sm_count * clock_mhz * 1e6 (cycles/s) / 1e12, i.e.
        ipc * sm_count * clock_mhz // 1e6 with integer truncation.
        """
        result = MathThroughput()
        for dtype in MathThroughput.__dataclass_fields__:
            per_sm = getattr(ipc_per_sm, dtype)
            setattr(result, dtype, per_sm * sm_count * clock_mhz // int(1e6))
        return result
|
|
|
|
|
|
|
|
@dataclass
class ClusterInfo(DictConversion):
    """Hardware characteristics of one cluster node / GPU type."""

    # Network bandwidth per device in GB/s (units match the values logged
    # by infer_cluster_info).
    inter_node_bw_per_device: int = 25
    intra_node_bw_per_device: int = 0
    # Link latencies; presumably microseconds — TODO confirm against callers.
    inter_node_latency: int = 10
    intra_node_latency: int = 10
    # In-network reduction (SHARP) availability on each fabric; set True
    # intra-node for NVLink generation >= 4 (see infer_cluster_info).
    intra_node_sharp: bool = False
    inter_node_sharp: bool = True

    # Device memory bandwidth in GB/s and capacity budget in GiB.
    memory_bw: int = 0
    memory_budget_per_device: int = 0

    # Peak math throughput per device, in TFLOPS.
    math_throughput: MathThroughput = field(default_factory=MathThroughput)

    # De-rating factors (0..1) applied on top of the peak numbers.
    memory_efficiency: float = 1.0
    math_efficiency: float = 1.0
    communication_efficiency: float = 1.0
|
|
|
|
|
|
|
|
# Math throughput (TFLOPS) shared by all A100 variants, regardless of
# form factor or memory size.
_math_throughputs = {
    "A100": MathThroughput(
        int8=624,
        float16=312,
        bfloat16=312,
        float32=156,
    ),
}
|
|
|
|
|
# PCIe bandwidth per generation in GB/s, presumably for an x16 link —
# used as intra-node bandwidth for PCIe-attached GPUs below.
_bandwidths = {
    "PCIe-3": 16,
    "PCIe-4": 32,
    "PCIe-5": 64,
}
|
|
|
|
|
# Known GPU node types, keyed by the names produced by infer_cluster_key
# (plus a few entries, e.g. "H20"/"L20"/"L2", selectable only by explicit key).
# Units: bandwidths in GB/s, memory_budget_per_device in GiB, math throughput
# in TFLOPS; values presumably follow vendor datasheets — not verified here.
cluster_infos = {
    # A100 variants.
    "A100-SXM-80GB":
    ClusterInfo(
        intra_node_bw_per_device=300,
        memory_bw=2039,
        memory_budget_per_device=80,
        math_throughput=_math_throughputs["A100"],
    ),
    "A100-SXM-40GB":
    ClusterInfo(
        intra_node_bw_per_device=300,
        memory_bw=1555,
        memory_budget_per_device=40,
        math_throughput=_math_throughputs["A100"],
    ),
    "A100-PCIe-80GB":
    ClusterInfo(
        intra_node_bw_per_device=_bandwidths["PCIe-4"],
        memory_bw=1935,
        memory_budget_per_device=80,
        math_throughput=_math_throughputs["A100"],
    ),
    "A100-PCIe-40GB":
    ClusterInfo(
        intra_node_bw_per_device=_bandwidths["PCIe-4"],
        memory_bw=1555,
        memory_budget_per_device=40,
        math_throughput=_math_throughputs["A100"],
    ),
    # Hopper variants; only the SXM part is flagged with intra-node SHARP.
    "H100-SXM":
    ClusterInfo(
        inter_node_bw_per_device=50,
        intra_node_bw_per_device=450,
        intra_node_sharp=True,
        memory_bw=3350,
        memory_budget_per_device=80,
        math_throughput=MathThroughput(
            int8=1979,
            fp8=1979,
            float16=989,
            bfloat16=989,
            float32=495,
        ),
    ),
    "H100-PCIe":
    ClusterInfo(
        inter_node_bw_per_device=50,
        intra_node_bw_per_device=_bandwidths["PCIe-5"],
        memory_bw=2000,
        memory_budget_per_device=80,
        math_throughput=MathThroughput(
            int8=1513,
            fp8=1513,
            float16=756,
            bfloat16=756,
            float32=378,
        ),
    ),
    "H20":
    ClusterInfo(
        inter_node_bw_per_device=50,
        intra_node_bw_per_device=450,
        memory_bw=4000,
        memory_budget_per_device=96,
        math_throughput=MathThroughput(
            int8=293,
            fp8=293,
            float16=147,
            bfloat16=147,
            float32=74,
        ),
    ),
    # V100 variants.
    "V100-PCIe-16GB":
    ClusterInfo(
        intra_node_bw_per_device=_bandwidths["PCIe-3"],
        memory_bw=900,
        memory_budget_per_device=16,
        math_throughput=MathThroughput(float32=112),
    ),
    "V100-PCIe-32GB":
    ClusterInfo(
        intra_node_bw_per_device=_bandwidths["PCIe-3"],
        memory_bw=900,
        memory_budget_per_device=32,
        math_throughput=MathThroughput(float32=112),
    ),
    "V100-SXM-16GB":
    ClusterInfo(
        intra_node_bw_per_device=150,
        memory_bw=900,
        memory_budget_per_device=16,
        math_throughput=MathThroughput(float32=125),
    ),
    "V100-SXM-32GB":
    ClusterInfo(
        intra_node_bw_per_device=150,
        memory_bw=900,
        memory_budget_per_device=32,
        math_throughput=MathThroughput(float32=125),
    ),
    "V100S-PCIe":
    ClusterInfo(
        intra_node_bw_per_device=_bandwidths["PCIe-3"],
        memory_bw=1134,
        memory_budget_per_device=32,
        math_throughput=MathThroughput(float32=130),
    ),
    # PCIe-attached Ampere parts.
    "A40":
    ClusterInfo(
        intra_node_bw_per_device=_bandwidths["PCIe-4"],
        memory_bw=696,
        memory_budget_per_device=48,
        math_throughput=MathThroughput(
            int4=600,
            int8=300,
            float16=150,
            bfloat16=150,
            float32=75,
        ),
    ),
    "A30":
    ClusterInfo(
        intra_node_bw_per_device=_bandwidths["PCIe-4"],
        memory_bw=933,
        memory_budget_per_device=24,
        math_throughput=MathThroughput(
            int4=661,
            int8=330,
            float16=165,
            bfloat16=165,
            float32=82,
        ),
    ),
    "A10":
    ClusterInfo(
        intra_node_bw_per_device=_bandwidths["PCIe-4"],
        memory_bw=600,
        memory_budget_per_device=24,
        math_throughput=MathThroughput(
            int4=500,
            int8=250,
            float16=125,
            bfloat16=125,
            float32=62.5,
        ),
    ),
    "A10G":
    ClusterInfo(
        intra_node_bw_per_device=_bandwidths["PCIe-4"],
        memory_bw=600,
        memory_budget_per_device=24,
        math_throughput=MathThroughput(
            int4=280,
            int8=140,
            float16=70,
            bfloat16=70,
            float32=35,
        ),
    ),
    # L-series parts.
    "L40S":
    ClusterInfo(
        intra_node_bw_per_device=_bandwidths["PCIe-4"],
        memory_bw=864,
        memory_budget_per_device=48,
        math_throughput=MathThroughput(
            int4=733,
            int8=733,
            fp8=733,
            float16=362,
            bfloat16=362,
            float32=183,
        ),
    ),
    "L40":
    ClusterInfo(
        intra_node_bw_per_device=_bandwidths["PCIe-4"],
        memory_bw=864,
        memory_budget_per_device=48,
        math_throughput=MathThroughput(
            int4=724,
            int8=362,
            fp8=362,
            float16=181,
            bfloat16=181,
            float32=90,
        ),
    ),
    "L20":
    ClusterInfo(
        intra_node_bw_per_device=_bandwidths["PCIe-4"],
        memory_bw=864,
        memory_budget_per_device=48,
        math_throughput=MathThroughput(
            int8=238,
            fp8=238,
            float16=119,
            bfloat16=119,
            float32=60,
        ),
    ),
    "L4":
    ClusterInfo(
        intra_node_bw_per_device=_bandwidths["PCIe-4"],
        memory_bw=300,
        memory_budget_per_device=24,
        math_throughput=MathThroughput(
            int8=242,
            fp8=242,
            float16=120,
            bfloat16=120,
            float32=60,
        ),
    ),
    "L2":
    ClusterInfo(
        intra_node_bw_per_device=_bandwidths["PCIe-4"],
        memory_bw=300,
        memory_budget_per_device=24,
        math_throughput=MathThroughput(
            int8=193,
            fp8=193,
            float16=97,
            bfloat16=97,
            float32=48,
        ),
    ),
}
|
|
|
|
|
|
|
|
def infer_cluster_key(device_name: Optional[str] = None) -> Optional[str]:
    """Map a GPU device name to the matching ``cluster_infos`` key.

    Args:
        device_name: GPU name as reported by ``torch.cuda.get_device_name``.
            When None (the default), the name of the current CUDA device is
            queried, preserving the original zero-argument behavior.

    Returns:
        The matching ``cluster_infos`` key, or None when the device is not
        recognized.  (The original annotation claimed ``str`` but the
        function always could fall through to None.)
    """
    if device_name is None:
        device_name = torch.cuda.get_device_name(torch.cuda.current_device())

    def match(product, name):
        # The product token must be followed by a space, a hyphen, or the end
        # of the string, so that e.g. "L4" does not match "L40".
        return re.match(f".*{product}([ -]|$).*", name) is not None

    def is_sxm():
        return "SXM" in device_name

    def is_80gb():
        return "80GB" in device_name

    def is_32gb():
        return "32GB" in device_name

    # Order matters: longer product tokens are tested before their prefixes
    # (A10G before A10, L40S before L40 before L4, V100S before V100).
    # NOTE(review): "H20", "L20" and "L2" exist in cluster_infos but are not
    # matched here; such devices fall through to None.
    if match("A100", device_name):
        if is_sxm():
            return "A100-SXM-80GB" if is_80gb() else "A100-SXM-40GB"
        else:
            return "A100-PCIe-80GB" if is_80gb() else "A100-PCIe-40GB"
    elif match("A10G", device_name):
        return "A10G"
    elif match("A10", device_name):
        return "A10"
    elif match("A30", device_name):
        return "A30"
    elif match("A40", device_name):
        return "A40"
    elif match("H100", device_name):
        return "H100-SXM" if is_sxm() else "H100-PCIe"
    elif match("L40S", device_name):
        return "L40S"
    elif match("L40", device_name):
        return "L40"
    elif match("L4", device_name):
        return "L4"
    elif match("V100S", device_name):
        return "V100S-PCIe"
    elif match("V100", device_name):
        if is_sxm():
            return "V100-SXM-32GB" if is_32gb() else "V100-SXM-16GB"
        else:
            return "V100-PCIe-32GB" if is_32gb() else "V100-PCIe-16GB"
    return None
|
|
|
|
|
|
|
|
def ipc_per_sm(compute_cap: Tuple[int, int]) -> MathThroughput:
    """Return per-SM tensor-op throughput (ops per clock) for a compute capability.

    The values feed MathThroughput.to_tflops together with the SM count and
    SM clock to estimate whole-device TFLOPS.  Unknown compute capabilities
    yield an all-zero MathThroughput.
    """
    # Keyed by CUDA compute capability (major, minor); fields that a given
    # architecture does not accelerate are left at the 0 default.
    ipc_table = {
        (9, 0):
        MathThroughput(
            int8=16384,
            fp8=16384,
            float16=8192,
            bfloat16=8192,
            float32=4096,
        ),
        (8, 0):
        MathThroughput(
            int4=8192,
            int8=4096,
            float16=2048,
            bfloat16=2048,
            float32=1024,
        ),
        (8, 6):
        MathThroughput(
            int4=4096,
            int8=2048,
            float16=1024,
            bfloat16=1024,
            float32=512,
        ),
        (8, 9):
        MathThroughput(
            int4=2048,
            int8=1024,
            fp8=1024,
            float16=512,
            bfloat16=512,
            float32=256,
        ),
        (7, 0):
        MathThroughput(
            float16=1024,
            float32=128,
        ),
        (7, 5):
        MathThroughput(
            int4=4096,
            int8=2048,
            float16=1024,
            float32=128,
        ),
    }
    return ipc_table.get(compute_cap, MathThroughput())
|
|
|
|
|
|
|
|
def nvlink_version(version_enum: int) -> int:
    """Translate the NVML NvLink version enum into an NVLink generation.

    Raises:
        KeyError: if the enum value is not one of the known values 1..7,
            mirroring the original dict-lookup behavior.
    """
    # Enum values 2-4 all report NVLink 2 links; 5-6 report NVLink 3.
    if version_enum == 1:
        return 1
    if version_enum in (2, 3, 4):
        return 2
    if version_enum in (5, 6):
        return 3
    if version_enum == 7:
        return 4
    raise KeyError(version_enum)
|
|
|
|
|
|
|
|
def nvlink_bandwidth(nvlink_version: int) -> int:
    """Return per-device NVLink bandwidth in GB/s for an NVLink generation.

    Raises:
        KeyError: for generations outside 1..4, mirroring the original
            dict-lookup behavior.
    """
    # Index i holds the bandwidth of NVLink generation i + 1.
    bw_by_generation = (80, 150, 300, 450)
    if not 1 <= nvlink_version <= len(bw_by_generation):
        raise KeyError(nvlink_version)
    return bw_by_generation[nvlink_version - 1]
|
|
|
|
|
|
|
|
def _parse_pynvml_version() -> Tuple[int, ...]:
    """Parse pynvml.__version__ into a tuple of ints for safe comparison.

    The original code compared version strings lexicographically
    (``pynvml.__version__ < '11.5.0'``), which is wrong for versions such as
    '8.0.1' ('8...' sorts *after* '11...' as a string).  Numeric tuples
    compare correctly.  Non-numeric trailing components are ignored; a fully
    unparsable version yields an empty tuple, which compares as older than
    any real version (falling back to the legacy code paths).
    """
    parts = []
    for piece in str(getattr(pynvml, "__version__", "")).split("."):
        if not piece.isdigit():
            break
        parts.append(int(piece))
    return tuple(parts)


def infer_cluster_info() -> ClusterInfo:
    """Build a ClusterInfo by querying the current CUDA device via NVML/cudart.

    Reads the compute capability, SM count, max clocks, memory size and bus
    width, and the NVLink/PCIe link configuration, then derives the peak math
    throughput and bandwidths from them.

    Returns:
        A ClusterInfo describing the current device.

    Raises:
        pynvml.NVMLError: if an NVML query fails (e.g. no NVIDIA driver).
    """
    device = torch.cuda.current_device()
    index = device.index if isinstance(device, torch.device) else device
    with PyNVMLContext():
        handle = pynvml.nvmlDeviceGetHandleByIndex(index)
        compute_cap = pynvml.nvmlDeviceGetCudaComputeCapability(handle)
        logger.info(f"Compute capability: {compute_cap}")
        # NOTE(review): the cudart error code is ignored; on failure
        # `properties` would hold default-constructed values — confirm this
        # best-effort behavior is intended.
        err, properties = cudart.cudaGetDeviceProperties(index)
        sm_count = properties.multiProcessorCount
        logger.info(f"SM count: {sm_count}")
        sm_clock = pynvml.nvmlDeviceGetMaxClockInfo(
            handle,
            pynvml.NVML_CLOCK_SM,
        )
        logger.info(f"SM clock: {sm_clock} MHz")
        math_throughput = MathThroughput.to_tflops(
            ipc_per_sm(compute_cap),
            sm_count,
            sm_clock,
        )
        for name in math_throughput.__dataclass_fields__:
            tflops = getattr(math_throughput, name)
            logger.info(f"{name} TFLOPS: {tflops}")

        mem_info = _device_get_memory_info_fn(handle)
        memory_budget = mem_info.total // (1024**3)
        logger.info(f"Total Memory: {memory_budget} GiB")

        mem_clock = pynvml.nvmlDeviceGetMaxClockInfo(
            handle,
            pynvml.NVML_CLOCK_MEM,
        )
        logger.info(f"Memory clock: {mem_clock} MHz")
        if _parse_pynvml_version() < (11, 5):
            # Older pynvml lacks nvmlDeviceGetMemoryBusWidth; fall back to
            # the cudart device properties.
            mem_bus_width = properties.memoryBusWidth
        else:
            mem_bus_width = pynvml.nvmlDeviceGetMemoryBusWidth(handle)
        logger.info(f"Memory bus width: {mem_bus_width}")
        # bus width [bits] * clock [MHz] * 2 (double data rate) / 8e3 -> GB/s
        memory_bw = mem_bus_width * mem_clock * 2 // int(8e3)
        logger.info(f"Memory bandwidth: {memory_bw} GB/s")

        try:
            is_nvl_active = bool(pynvml.nvmlDeviceGetNvLinkState(handle, 0))
            logger.info(f"NVLink is active: {is_nvl_active}")
        except pynvml.NVMLError:
            # Devices without NVLink raise here; probe PCIe instead below.
            is_nvl_active = False

        intra_node_sharp = False
        if is_nvl_active:
            nvl_version_enum = pynvml.nvmlDeviceGetNvLinkVersion(handle, 0)
            nvl_version = nvlink_version(nvl_version_enum)
            logger.info(f"NVLink version: {nvl_version}")
            nvl_bw = nvlink_bandwidth(nvl_version)
            logger.info(f"NVLink bandwidth: {nvl_bw} GB/s")
            intra_node_bw = nvl_bw
            if nvl_version >= 4:
                # NVLink generation >= 4 is modeled as supporting in-network
                # reduction (SHARP) intra-node.
                intra_node_sharp = True
        else:
            if _parse_pynvml_version() < (11, 5):
                # Older pynvml lacks nvmlDeviceGetPcieSpeed; estimate from
                # the link generation: gen n -> 2**n * 1000 Mbps per lane.
                pcie_gen = pynvml.nvmlDeviceGetCurrPcieLinkGeneration(handle)
                pcie_speed = (2**pcie_gen) * 1000
            else:
                pcie_speed = pynvml.nvmlDeviceGetPcieSpeed(handle)
            logger.info(f"PCIe speed: {pcie_speed} Mbps")
            pcie_link_width = pynvml.nvmlDeviceGetCurrPcieLinkWidth(handle)
            logger.info(f"PCIe link width: {pcie_link_width}")
            # speed [Mbps per lane] * lanes / 8e3 -> GB/s
            pcie_bw = pcie_speed * pcie_link_width // int(8e3)
            logger.info(f"PCIe bandwidth: {pcie_bw} GB/s")
            intra_node_bw = pcie_bw

        cluster_info = ClusterInfo(
            math_throughput=math_throughput,
            memory_bw=memory_bw,
            memory_budget_per_device=memory_budget,
            intra_node_bw_per_device=intra_node_bw,
            intra_node_sharp=intra_node_sharp,
        )
        return cluster_info
|
|
|
|
|
|
|
|
def infer_cluster_config() -> Dict[str, Union[str, ClusterInfo]]:
    """Infer the cluster configuration of the current CUDA device.

    Returns a dict with a ``cluster_key`` entry; when the device is not one
    of the known keys, a ``cluster_info`` entry is added as well, either
    probed live from the hardware or, if probing fails, copied from the L40
    entry with the real memory size substituted in.
    """
    gpu_name = torch.cuda.get_device_name(torch.cuda.current_device())
    known_key = infer_cluster_key()
    if known_key is not None:
        # Known device: the key alone identifies an entry in cluster_infos.
        return dict(cluster_key=known_key)
    try:
        info = infer_cluster_info()
    except pynvml.NVMLError:
        # Probing failed — fall back to the L40 profile, corrected for the
        # actual total device memory.
        fallback_cluster_key = "L40"
        info = copy.copy(cluster_infos[fallback_cluster_key])
        memory_budget = torch.cuda.mem_get_info()[1] // (1024**3)
        info.memory_budget_per_device = memory_budget
        logger.warning(
            f"Failed to infer cluster info for {gpu_name}, "
            f"treat it as a {fallback_cluster_key} node with {memory_budget} GB memory. "
            "This setting makes no effect if you do not use auto parallel.")
    return dict(
        cluster_key=gpu_name.replace(" ", "-"),
        cluster_info=info,
    )
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Manual smoke test: log the inferred hardware properties of the current
    # CUDA device at INFO level.
    logger.set_level("info")
    infer_cluster_info()
|
|
|