import numpy as np
import tensorrt as trt
import torch

from tensorrt_llm.logger import logger
from tensorrt_llm.network import get_plugin_info

from .shape_info import get_per_layer_graph
from .utils import get_cache_key, get_trt_network, get_updated_plugin


class NvtxProfiler(object):
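    """Context manager that wraps a code region in an NVTX range.

    The range appears in NVTX-aware profilers (e.g. Nsight Systems) under
    ``nvtx_name``; the annotation can be switched off with ``enable=False``.
    """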

    def __init__(self, nvtx_name, enable=True):
        self.nvtx_name = nvtx_name
        self.enable = enable

    def __enter__(self):
        if self.enable:
            torch.cuda.nvtx.range_push(self.nvtx_name)

    def __exit__(self, exc_type, exc_val, exc_tb):
        if self.enable:
            torch.cuda.nvtx.range_pop()


class LayerProfiler(trt.IProfiler):
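    """TensorRT ``IProfiler`` that accumulates per-layer execution time.

    ``report_layer_time`` is invoked by the runtime once per layer per
    execution; the cumulative time (in ms) and the number of reported
    layers are kept on the instance.
    """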

    def __init__(self):
        trt.IProfiler.__init__(self)
        self.layer_count = 0
        self.time = 0

    def report_layer_time(self, layer_name, ms):
        logger.debug(f'{layer_name=}, {self.layer_count=}, time = {ms} ms')
        self.time += ms
        self.layer_count += 1


class RuntimeProfiler(object):
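    """Measures the runtime of a single layer under a given sharding strategy.

    Measurements are stored in the device mesh's profiling database so that
    identical layer/shape/sharding combinations are only profiled once.
    """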

    def __init__(self):
        self.timing_cache = None

    def _profile(self, layer, layer_attrs, shapes, values, io_buffer_mapping):
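        """Build a single-layer network and measure its execution time.

        The layer is executed twice; the first run acts as a warm-up, and the
        returned value is the second run's layer time converted from ms to us.
        """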
        is_plugin = layer.type == trt.LayerType.PLUGIN_V2
        if is_plugin and len(layer_attrs) > 0:
            plugin_info = get_plugin_info(
                get_trt_network(layer),
                layer.name,
            )
            new_plugin, _ = get_updated_plugin(plugin_info, layer_attrs)
            layer_attrs = {"plugin": new_plugin}
        graph, output_mapping = get_per_layer_graph(layer, shapes, values,
                                                    layer_attrs)
        graph._io_buffer_mapping = io_buffer_mapping
        network = graph.as_trt()
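        # If every network output is a shape tensor there is no meaningful
        # kernel time to measure, so report zero cost.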
        if network.num_outputs > 0 and np.all([
                network.get_output(i).is_shape_tensor
                for i in range(network.num_outputs)
        ]):
            return 0.0
        for proxy_output, output in output_mapping.items():
            shapes[proxy_output] = shapes[output]
        if not self.timing_cache:
            self.timing_cache = network.builder.create_builder_config(
            ).create_timing_cache(b"")
        runner = graph.get_runner(
            shapes,
            values,
            timing_cache=self.timing_cache,
        )
        context = runner.session.context
        context.profiler = LayerProfiler()
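        # Run twice: the first run absorbs warm-up overhead, so the layer time
        # is the difference of the cumulative profiler time between the two
        # runs, converted from ms to us.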
        runner.run()
        profiler_time_first_run = context.profiler.time
        runner.run()
        return (context.profiler.time - profiler_time_first_run) * 1000.0

    def runtime_profile(self, layer, layer_attrs, input_values, strategy,
                        device_mesh):
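        """Return the measured runtime (in us) of ``layer`` under ``strategy``.

        Sharded input/output shapes are derived from the strategy's sharding
        specs, the profiling database is queried first, and on a cache miss
        the layer is actually profiled and the result is stored back together
        with the strategy's alpha-beta cost.
        """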
        logger.debug(f"start to profile layer {layer.name}")
        shapes = {}
        values = {}
        dtypes = {}
        trt_layer = layer.as_trt()

        sharding_sequences = ()
        for i in range(layer.num_inputs):
            input = trt_layer.get_input(i)
            if input is not None:
                shapes[input.name] = strategy.sharding_specs[
                    f'input{i}'].get_sharded_shape_per_device()
                dtypes[input.name] = input.dtype
                sharding_sequences += (str(
                    strategy.sharding_specs[f"input{i}"].sharding_sequence), )
                if i in input_values:
                    values[input.name] = input_values[i]
                else:
                    value = layer.get_input(i).value
                    if value is not None:
                        values[input.name] = value
            else:
                sharding_sequences += (None, )

        for i in range(layer.num_outputs):
            output = trt_layer.get_output(i)
            if f'output{i}' in strategy.communication_actions:
                shapes[output.name] = strategy.communication_actions[
                    f'output{i}'].sharding_spec.get_sharded_shape_per_device()
            else:
                shapes[output.name] = strategy.sharding_specs[
                    f'output{i}'].get_sharded_shape_per_device()
            dtypes[output.name] = output.dtype
            sharding_sequences += (str(
                strategy.sharding_specs[f"output{i}"].sharding_sequence), )
        data_key = get_cache_key(
            trt_layer,
            shapes,
            values,
            dtypes=dtypes,
            updated_attrs=layer_attrs,
        )
        data_key += (sharding_sequences, )
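        # Reuse a previous measurement for this layer/shape/sharding
        # combination before paying for a real profiling run.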
        elapsed_time = device_mesh.prof_database.query(
            device_mesh.cluster_key,
            data_key,
        )
        if elapsed_time:
            logger.debug(
                f'runtime profiling cache hit {data_key}: {elapsed_time} us')
            return elapsed_time
        with NvtxProfiler(f'{layer.name}_{data_key}', enable=True):
            elapsed_time = self._profile(
                layer.as_trt(),
                layer_attrs,
                shapes,
                values,
                layer.graph._io_buffer_mapping,
            )
        logger.debug(
            f'runtime profiling cache miss {data_key}: {elapsed_time} us')

        device_mesh.prof_database.update(
            device_mesh.cluster_key,
            data_key,
            (elapsed_time, strategy.alpha_beta_cost),
        )

        return elapsed_time