File size: 10,239 Bytes
5000658
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
import gc
import os
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path

import tensorrt as trt
import torch
from filelock import FileLock

from tensorrt_llm.functional import DimRange, Tensor
from tensorrt_llm.logger import logger
from tensorrt_llm.network import Network, net_guard

from .config import AutoParallelConfig
from .device_mesh import LogicalDeviceMesh, PhysicalDeviceMesh
from .node_graph import NodeGraph
from .parallelization import ParallelConfig, parallelize
from .pipeline_graph import PipelineGraph
from .simplifier import GraphConfig, Simplifier, StageType
from .utils import current_flags


def to_network(graph: PipelineGraph, network: Network):
    """Wrap the TRT network held by ``graph`` in a fresh :class:`Network`.

    Carries over the dtype, plugin config, unfilled weights and
    auto-parallel config, and rebuilds the input-tensor wrappers with
    optimization profiles taken from the original ``network``.
    """
    logger.debug("Converting graph to network")
    trt_network = graph.as_trt()
    trt_network.name = network.trt_network.name
    result = Network()
    result._init(trt_network)
    result._dtype = network._dtype
    result._plugin_config = network._plugin_config
    result._unfilled_weights = graph._unfilled_weights
    result._auto_parallel_config = graph._auto_parallel_config
    with net_guard(network):
        for idx in range(trt_network.num_inputs):
            trt_input = trt_network.get_input(idx)
            wrapper = Tensor(is_network_input=False)
            if trt_input.name in network._inputs:
                # Known input: reuse the profiles recorded on the original.
                profiles = wrapper_profiles = network._inputs[
                    trt_input.name].profiles
            elif not network._inputs:
                # Original network had no inputs at all: nothing to copy.
                profiles = []
            else:
                # New input: give it unconstrained dim ranges, replicated
                # once per optimization profile of the original network.
                shape = trt_input.shape
                num_profiles = len(list(network._inputs.values())[0].profiles)
                unconstrained = DimRange(shape, [None] * len(shape))
                profiles = [unconstrained] * num_profiles
            wrapper.profiles = profiles
            wrapper.trt_tensor = trt_input
            result._inputs[trt_input.name] = wrapper
    return result


def find_solution(
    node_graph: NodeGraph,
    graph_config: GraphConfig,
    lmesh: LogicalDeviceMesh,
    memory_budget: int,
    flags: list,
    device: int,
    dump_path: str,
) -> ParallelConfig:
    """Search ``node_graph`` for the cheapest sharding on ``lmesh``.

    Tries each applicable pipeline stage type, keeps the lowest-cost
    solution, and (when ``dump_path`` is set) writes a visualization of
    the winning solution under a file lock so concurrent workers do not
    clobber each other's output.
    """
    torch.cuda.set_device(device)
    with current_flags(*flags):
        cost_graph = node_graph.get_cost_graph(lmesh)
        num_stages = graph_config.num_stages
        # With a single stage there is nothing stage-specific to optimize;
        # otherwise each distinct stage kind gets its own search pass.
        if num_stages == 1:
            candidate_stage_types = [None]
        elif num_stages == 2:
            candidate_stage_types = [StageType.START, StageType.END]
        else:
            candidate_stage_types = [
                StageType.START, StageType.BLOCK, StageType.END
            ]

        winner, winning_solution = None, None
        for stage_type in candidate_stage_types:
            if stage_type is not None:
                node_graph.set_slowest_stage(stage_type, graph_config)
            solution = node_graph.find_solution(cost_graph, memory_budget)
            if winner is None or solution.total_cost < winner.cost:
                candidate = ParallelConfig()
                candidate.graph_config = graph_config
                candidate.lmesh = lmesh
                candidate.cost = solution.total_cost
                candidate.graph_strategy = solution.node_best_strategy
                candidate.stage_type = stage_type
                winner, winning_solution = candidate, solution
        if dump_path is not None:
            lock = FileLock(f"{dump_path}/path.lock", thread_local=False)
            vlz_name = f"{dump_path}/solution."
            if graph_config.num_micro_batches != 1:
                vlz_name += f"mbs{graph_config.num_micro_batches}."
            if graph_config.num_stages != 1:
                vlz_name += f"stages{graph_config.num_stages}."
            vlz_name += lmesh.cluster_key
            with lock:
                node_graph.visualize_solution(
                    winning_solution,
                    vlz_name,
                    ignore_shape_io=True,
                )
        return winner


def infer_builder_flags(network):
    """Derive the TRT builder-flag bitmask from the dtypes present in
    ``network``'s inputs and layer outputs.

    Enables FP16/BF16/INT8/FP8 flags as needed; reduced-precision float
    types additionally request OBEY_PRECISION_CONSTRAINTS.
    """
    seen_dtypes = set()

    trt_network = network.trt_network
    for i in range(trt_network.num_inputs):
        seen_dtypes.add(trt_network.get_input(i).dtype)
    for i in range(trt_network.num_layers):
        layer = trt_network.get_layer(i)
        for j in range(layer.num_outputs):
            seen_dtypes.add(layer.get_output(j).dtype)

    builder_flags = 0
    if trt.DataType.HALF in seen_dtypes:
        builder_flags |= 1 << int(trt.BuilderFlag.FP16)
        builder_flags |= 1 << int(trt.BuilderFlag.OBEY_PRECISION_CONSTRAINTS)
    if trt.DataType.BF16 in seen_dtypes:
        builder_flags |= 1 << int(trt.BuilderFlag.BF16)
        builder_flags |= 1 << int(trt.BuilderFlag.OBEY_PRECISION_CONSTRAINTS)
    if trt.DataType.INT8 in seen_dtypes:
        builder_flags |= 1 << int(trt.BuilderFlag.INT8)
    if trt.DataType.FP8 in seen_dtypes:
        builder_flags |= 1 << int(trt.BuilderFlag.FP8)
        builder_flags |= 1 << int(trt.BuilderFlag.OBEY_PRECISION_CONSTRAINTS)
    return builder_flags


def auto_parallel(network: Network, config: AutoParallelConfig):
    """Split ``network`` into per-rank networks according to ``config``.

    Returns a list of Network objects, one per parallel rank (or the
    original network in a single-element list when world_size is 1).
    Uses a cached ParallelConfig when a valid one exists; otherwise
    searches all (micro-batch, pipeline, logical-mesh) combinations in
    parallel and picks the cheapest solution.
    """
    debug_mode = config.debug_mode
    # Cluster memory budget is given in GiB; convert to bytes.
    memory_budget = config.get_cluster_info(
    ).memory_budget_per_device * 1024 * 1024 * 1024
    enable_pipeline_parallelism = config.enable_pipeline_parallelism
    # Derive the host/device topology: either all devices fit on one
    # host, or the world size must be a whole number of full hosts.
    if config.world_size < config.gpus_per_node:
        num_hosts = 1
        num_devices_per_host = config.world_size
    else:
        assert config.world_size % config.gpus_per_node == 0
        num_hosts = config.world_size // config.gpus_per_node
        num_devices_per_host = config.gpus_per_node
    parallel_config_cache = config.parallel_config_cache
    # Solution dumps are only produced in debug mode.
    dump_path = config.dump_path if debug_mode else None
    fill_weights = config.fill_weights

    # Single device: nothing to parallelize.
    if num_hosts == 1 and num_devices_per_host == 1:
        return [network]

    if dump_path is not None:
        if not os.path.exists(dump_path):
            os.makedirs(dump_path)

    builder_flags = config.builder_flags or infer_builder_flags(network)
    flags = [builder_flags, network.strongly_typed]
    with current_flags(*flags):
        simplifier = Simplifier(network, config)
        network_hash = simplifier.get_network_hash()

        best_config = None
        # Try to reuse a previously computed parallel config; it is only
        # valid if the format version, network hash, and auto-parallel
        # config all match.
        if parallel_config_cache is not None and Path(
                parallel_config_cache).exists():
            parallel_config = ParallelConfig.from_file(parallel_config_cache)
            if (ParallelConfig.VERSION == parallel_config.version
                    and network_hash == parallel_config.network_hash
                    and config == parallel_config.auto_parallel_config):
                logger.info(
                    f"use cache of parallel config from {parallel_config_cache}"
                )
                best_config = parallel_config

        if best_config is None:
            num_devices = num_hosts * num_devices_per_host
            # Physical device ids laid out host-major:
            # phy_ids[j][i] is device i on host j.
            phy_ids = [[
                i + j * num_devices_per_host
                for i in range(num_devices_per_host)
            ] for j in range(num_hosts)]
            phy_mesh = PhysicalDeviceMesh(phy_ids, config)
            if enable_pipeline_parallelism:
                num_micro_batches_list = simplifier.list_all_num_micro_batches()
            else:
                num_micro_batches_list = [1]

            # Build one search job per (micro-batch count, pipeline
            # config, logical mesh) combination.
            jobs = []
            for num_micro_batches in num_micro_batches_list:
                simplifier.infer_shapes(num_micro_batches)
                if enable_pipeline_parallelism:
                    pipeline_configs = phy_mesh.list_all_pipeline_configs()
                else:
                    pipeline_configs = [(1, num_devices)]
                for num_stages, num_devices_per_stage in pipeline_configs:
                    # TODO: add fallback path that allows num_micro_batches >= num_stages
                    #       if no solution satisfies memory budget
                    if num_micro_batches < num_stages:
                        continue
                    simplified_graph, graph_config = simplifier.simplify_graph(
                        phy_mesh,
                        num_stages,
                        num_devices_per_stage,
                    )
                    if simplified_graph is None:
                        continue
                    node_graph = NodeGraph(simplified_graph)
                    node_graph.assign_cost_weights(graph_config)
                    lmeshes = graph_config.stage_phy_meshes[
                        0].get_logical_meshes()
                    for lmesh in lmeshes:
                        # Per-stage budget scales with how many devices
                        # each stage spans relative to the whole mesh.
                        jobs.append(
                            (node_graph, graph_config, lmesh, memory_budget *
                             (num_devices / num_devices_per_stage)))

            # Fan the jobs out over threads and keep the cheapest config;
            # the mesh is closed even if a job raises.
            try:
                with ThreadPoolExecutor() as executor:
                    best_config = sorted(
                        executor.map(
                            lambda x: find_solution(
                                *x,
                                flags,
                                torch.cuda.current_device(),
                                dump_path,
                            ),
                            jobs,
                        ),
                        key=lambda x: x.cost,
                    )[0]
            finally:
                phy_mesh.close()

            # Persist the freshly computed config for future runs.
            if parallel_config_cache is not None:
                best_config.network_hash = network_hash
                best_config.auto_parallel_config = config
                best_config.save(parallel_config_cache)

        new_graphs = parallelize(simplifier, best_config)

    networks = [to_network(new_graph, network) for new_graph in new_graphs]
    if debug_mode and fill_weights:
        networks[0]._fill_weights()

    # Free the intermediate graphs and any cached GPU allocations
    # before handing the per-rank networks back to the builder.
    gc.collect()
    torch.cuda.empty_cache()

    return networks