|
|
import gc |
|
|
import os |
|
|
from concurrent.futures import ThreadPoolExecutor |
|
|
from pathlib import Path |
|
|
|
|
|
import tensorrt as trt |
|
|
import torch |
|
|
from filelock import FileLock |
|
|
|
|
|
from tensorrt_llm.functional import DimRange, Tensor |
|
|
from tensorrt_llm.logger import logger |
|
|
from tensorrt_llm.network import Network, net_guard |
|
|
|
|
|
from .config import AutoParallelConfig |
|
|
from .device_mesh import LogicalDeviceMesh, PhysicalDeviceMesh |
|
|
from .node_graph import NodeGraph |
|
|
from .parallelization import ParallelConfig, parallelize |
|
|
from .pipeline_graph import PipelineGraph |
|
|
from .simplifier import GraphConfig, Simplifier, StageType |
|
|
from .utils import current_flags |
|
|
|
|
|
|
|
|
def to_network(graph: PipelineGraph, network: Network) -> Network:
    """Wrap a parallelized ``PipelineGraph`` back into a tensorrt_llm ``Network``.

    Copies the dtype/plugin configuration and unfilled-weight bookkeeping from
    the original *network*, then rebuilds the input-tensor table so each TRT
    input carries optimization profiles consistent with the original inputs.

    Args:
        graph: the parallelized pipeline graph to convert.
        network: the original network the graph was derived from; supplies
            naming, dtype, plugin config and reference input profiles.

    Returns:
        A new ``Network`` backed by ``graph``'s TRT network.
    """
    logger.debug("Converting graph to network")
    trt_network = graph.as_trt()
    trt_network.name = network.trt_network.name
    new_network = Network()
    new_network._init(trt_network)
    # Carry over build-time configuration from the source network.
    new_network._dtype = network._dtype
    new_network._plugin_config = network._plugin_config
    # Weights that still need to be filled in come from the graph, which
    # tracks them across the parallelization transform.
    new_network._unfilled_weights = graph._unfilled_weights
    new_network._auto_parallel_config = graph._auto_parallel_config
    with net_guard(network):
        for i in range(trt_network.num_inputs):
            # NOTE: renamed from `input` to avoid shadowing the builtin.
            trt_input = trt_network.get_input(i)
            tensor = Tensor(is_network_input=False)
            if trt_input.name in network._inputs:
                # Input existed in the original network: reuse its profiles.
                profiles = network._inputs[trt_input.name].profiles
            elif len(network._inputs) == 0:
                # Original network declared no profiled inputs at all.
                profiles = []
            else:
                # New input introduced by parallelization: synthesize a
                # fully-unconstrained DimRange, replicated once per profile
                # of an existing input so profile counts stay aligned.
                shape = trt_input.shape
                num_profiles = len(list(network._inputs.values())[0].profiles)
                profile = DimRange(shape, [None] * len(shape))
                profiles = [profile] * num_profiles
            tensor.profiles = profiles
            tensor.trt_tensor = trt_input
            new_network._inputs[trt_input.name] = tensor
    return new_network
|
|
|
|
|
|
|
|
def find_solution(
    node_graph: NodeGraph,
    graph_config: GraphConfig,
    lmesh: LogicalDeviceMesh,
    memory_budget: int,
    flags: list,
    device: int,
    dump_path: str,
) -> ParallelConfig:
    """Search *node_graph* for the cheapest parallelization on *lmesh*.

    Tries every stage type relevant for the pipeline depth, solves the cost
    graph under *memory_budget* for each, and keeps the lowest-cost result.
    When *dump_path* is set, the winning solution is visualized there under a
    file lock (this function may run concurrently in a thread pool).

    Returns:
        The best ``ParallelConfig`` found.
    """
    torch.cuda.set_device(device)
    with current_flags(*flags):
        cost_graph = node_graph.get_cost_graph(lmesh)

        # A 1-stage graph has no stage role; 2 stages have no middle block.
        stage_count = graph_config.num_stages
        if stage_count == 1:
            candidate_stage_types = [None]
        elif stage_count == 2:
            candidate_stage_types = [StageType.START, StageType.END]
        else:
            candidate_stage_types = [
                StageType.START, StageType.BLOCK, StageType.END
            ]

        best_config = None
        best_solution = None
        for stage_type in candidate_stage_types:
            if stage_type is not None:
                node_graph.set_slowest_stage(stage_type, graph_config)
            candidate = node_graph.find_solution(cost_graph, memory_budget)
            candidate_cost = candidate.total_cost
            if best_config is None or candidate_cost < best_config.cost:
                winner = ParallelConfig()
                winner.graph_config = graph_config
                winner.lmesh = lmesh
                winner.cost = candidate_cost
                winner.graph_strategy = candidate.node_best_strategy
                winner.stage_type = stage_type
                best_config = winner
                best_solution = candidate

        if dump_path is not None:
            # Serialize dumps across worker threads writing to the same dir.
            lock = FileLock(f"{dump_path}/path.lock", thread_local=False)
            vlz_name = f"{dump_path}/solution."
            if graph_config.num_micro_batches != 1:
                vlz_name += f"mbs{graph_config.num_micro_batches}."
            if graph_config.num_stages != 1:
                vlz_name += f"stages{graph_config.num_stages}."
            vlz_name += lmesh.cluster_key
            with lock:
                node_graph.visualize_solution(best_solution,
                                              vlz_name,
                                              ignore_shape_io=True)
        return best_config
|
|
|
|
|
|
|
|
def infer_builder_flags(network):
    """Derive TensorRT builder flags from the dtypes seen in *network*.

    Scans every network input and every layer output; any occurrence of
    FP16/BF16/INT8/FP8 turns on the matching builder flag (with
    OBEY_PRECISION_CONSTRAINTS for the floating-point reduced precisions).

    Returns:
        An int bitmask of ``trt.BuilderFlag`` values.
    """
    # Collect every dtype present on inputs and layer outputs.
    seen_dtypes = set()

    trt_network = network.trt_network
    for i in range(trt_network.num_inputs):
        seen_dtypes.add(trt_network.get_input(i).dtype)
    for i in range(trt_network.num_layers):
        layer = trt_network.get_layer(i)
        for j in range(layer.num_outputs):
            seen_dtypes.add(layer.get_output(j).dtype)

    builder_flags = 0
    if trt.DataType.HALF in seen_dtypes:
        builder_flags |= 1 << int(trt.BuilderFlag.FP16)
        builder_flags |= 1 << int(trt.BuilderFlag.OBEY_PRECISION_CONSTRAINTS)
    if trt.DataType.BF16 in seen_dtypes:
        builder_flags |= 1 << int(trt.BuilderFlag.BF16)
        builder_flags |= 1 << int(trt.BuilderFlag.OBEY_PRECISION_CONSTRAINTS)
    if trt.DataType.INT8 in seen_dtypes:
        builder_flags |= 1 << int(trt.BuilderFlag.INT8)
    if trt.DataType.FP8 in seen_dtypes:
        builder_flags |= 1 << int(trt.BuilderFlag.FP8)
        builder_flags |= 1 << int(trt.BuilderFlag.OBEY_PRECISION_CONSTRAINTS)
    return builder_flags
|
|
|
|
|
|
|
|
def auto_parallel(network: Network, config: AutoParallelConfig):
    """Automatically parallelize *network* across the devices in *config*.

    Pipeline: (1) early-out for a single device; (2) look for a cached
    ``ParallelConfig`` matching this network's hash and config; (3) otherwise
    enumerate (micro-batch count, pipeline config, logical mesh) combinations
    and solve each in a thread pool via ``find_solution``, keeping the
    cheapest; (4) apply the winning config with ``parallelize`` and convert
    each resulting pipeline graph back into a ``Network``.

    Returns:
        A list of ``Network`` objects, one per rank/stage (or ``[network]``
        unchanged when world size is 1).
    """
    debug_mode = config.debug_mode
    # Cluster budget is given in GiB per device; convert to bytes.
    memory_budget = config.get_cluster_info(
    ).memory_budget_per_device * 1024 * 1024 * 1024
    enable_pipeline_parallelism = config.enable_pipeline_parallelism
    if config.world_size < config.gpus_per_node:
        # Smaller than one full node: everything fits on a single host.
        num_hosts = 1
        num_devices_per_host = config.world_size
    else:
        # Multi-node runs must use whole nodes.
        assert config.world_size % config.gpus_per_node == 0
        num_hosts = config.world_size // config.gpus_per_node
        num_devices_per_host = config.gpus_per_node
    parallel_config_cache = config.parallel_config_cache
    # Solution dumps are only produced in debug mode.
    dump_path = config.dump_path if debug_mode else None
    fill_weights = config.fill_weights

    # Single device: nothing to parallelize.
    if num_hosts == 1 and num_devices_per_host == 1:
        return [network]

    if dump_path is not None:
        if not os.path.exists(dump_path):
            os.makedirs(dump_path)

    # Explicit builder flags win; otherwise infer them from network dtypes.
    builder_flags = config.builder_flags or infer_builder_flags(network)
    flags = [builder_flags, network.strongly_typed]
    with current_flags(*flags):
        simplifier = Simplifier(network, config)
        # Hash identifies the network for the parallel-config cache below.
        network_hash = simplifier.get_network_hash()

        best_config = None
        if parallel_config_cache is not None and Path(
                parallel_config_cache).exists():
            parallel_config = ParallelConfig.from_file(parallel_config_cache)
            # Reuse the cache only if format version, network contents and
            # auto-parallel settings all still match.
            if (ParallelConfig.VERSION == parallel_config.version
                    and network_hash == parallel_config.network_hash
                    and config == parallel_config.auto_parallel_config):
                logger.info(
                    f"use cache of parallel config from {parallel_config_cache}"
                )
                best_config = parallel_config

        if best_config is None:
            num_devices = num_hosts * num_devices_per_host
            # phy_ids[j][i] is the global device id of device i on host j.
            phy_ids = [[
                i + j * num_devices_per_host
                for i in range(num_devices_per_host)
            ] for j in range(num_hosts)]
            phy_mesh = PhysicalDeviceMesh(phy_ids, config)
            if enable_pipeline_parallelism:
                num_micro_batches_list = simplifier.list_all_num_micro_batches()
            else:
                num_micro_batches_list = [1]

            # Enumerate every candidate (graph, mesh, budget) search job.
            jobs = []
            for num_micro_batches in num_micro_batches_list:
                # NOTE: infer_shapes mutates simplifier state for this
                # micro-batch count before the graphs below are built.
                simplifier.infer_shapes(num_micro_batches)
                if enable_pipeline_parallelism:
                    pipeline_configs = phy_mesh.list_all_pipeline_configs()
                else:
                    pipeline_configs = [(1, num_devices)]
                for num_stages, num_devices_per_stage in pipeline_configs:
                    # Cannot have fewer micro-batches than pipeline stages.
                    if num_micro_batches < num_stages:
                        continue
                    simplified_graph, graph_config = simplifier.simplify_graph(
                        phy_mesh,
                        num_stages,
                        num_devices_per_stage,
                    )
                    # Infeasible stage split for this graph: skip it.
                    if simplified_graph is None:
                        continue
                    node_graph = NodeGraph(simplified_graph)
                    node_graph.assign_cost_weights(graph_config)
                    # All stages share a mesh layout; the first stage's
                    # logical meshes represent every candidate.
                    lmeshes = graph_config.stage_phy_meshes[
                        0].get_logical_meshes()
                    for lmesh in lmeshes:
                        # Budget is scaled up because a stage only holds
                        # 1/num_stages of the model per device group.
                        jobs.append(
                            (node_graph, graph_config, lmesh, memory_budget *
                             (num_devices / num_devices_per_stage)))

            try:
                # Solve all jobs concurrently; pick the cheapest solution.
                with ThreadPoolExecutor() as executor:
                    best_config = sorted(
                        executor.map(
                            lambda x: find_solution(
                                *x,
                                flags,
                                torch.cuda.current_device(),
                                dump_path,
                            ),
                            jobs,
                        ),
                        key=lambda x: x.cost,
                    )[0]
            finally:
                # Always release mesh resources, even if the search fails.
                phy_mesh.close()

            if parallel_config_cache is not None:
                # Stamp the result so the cache check above can validate it.
                best_config.network_hash = network_hash
                best_config.auto_parallel_config = config
                best_config.save(parallel_config_cache)

        new_graphs = parallelize(simplifier, best_config)

    networks = [to_network(new_graph, network) for new_graph in new_graphs]
    if debug_mode and fill_weights:
        networks[0]._fill_weights()

    # Free intermediate graph/solver memory before engine building starts.
    gc.collect()
    torch.cuda.empty_cache()

    return networks
|
|
|