| import numpy as np | |
| import tensorrt as trt | |
| import torch | |
| from tensorrt_llm.logger import logger | |
| from tensorrt_llm.network import get_plugin_info | |
| from .shape_info import get_per_layer_graph | |
| from .utils import get_cache_key, get_trt_network, get_updated_plugin | |
class NvtxProfiler(object):
    """Context manager that brackets a code region with an NVTX range.

    The range shows up in NVIDIA profiling tools (e.g. Nsight Systems) under
    ``nvtx_name``.  With ``enable=False`` both ``__enter__`` and ``__exit__``
    are no-ops, so the wrapper can be left in place unconditionally.
    """

    def __init__(self, nvtx_name, enable=True):
        # Label of the NVTX range pushed on entry.
        self.nvtx_name = nvtx_name
        # When False, entry/exit do nothing (profiling disabled).
        self.enable = enable

    def __enter__(self):
        if self.enable:
            torch.cuda.nvtx.range_push(self.nvtx_name)
        # Fix: return self so `with NvtxProfiler(...) as prof:` binds the
        # profiler instead of None (original returned nothing).
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Pop the range even if the body raised; exceptions propagate
        # (this method implicitly returns None, i.e. falsy).
        if self.enable:
            torch.cuda.nvtx.range_pop()
class LayerProfiler(trt.IProfiler):
    """TensorRT ``IProfiler`` that accumulates total layer time and a layer count.

    TensorRT invokes :meth:`report_layer_time` once per executed layer; this
    implementation sums the reported times into ``self.time`` and counts the
    callbacks in ``self.layer_count``.
    """

    def __init__(self):
        super().__init__()
        # Number of report_layer_time callbacks received so far.
        self.layer_count = 0
        # Running sum of per-layer times as reported by TensorRT.
        self.time = 0

    def report_layer_time(self, layer_name, ms):
        # Log before updating, so the logged count reflects this layer's index.
        logger.debug(f'{layer_name=}, {self.layer_count=}, time = {ms} ms')
        self.layer_count += 1
        self.time += ms
class RuntimeProfiler(object):
    """Measures the runtime cost of individual TRT layers under sharding strategies.

    Results are cached in ``device_mesh.prof_database`` keyed by cluster key plus
    a layer/shape/value/dtype/sharding cache key, so each unique configuration is
    profiled at most once.
    """

    def __init__(self):
        # Lazily-created TensorRT timing cache, shared across all per-layer
        # builds so tactic timings are reused between profiling runs.
        self.timing_cache = None

    def _profile(self, layer, layer_attrs, shapes, values, io_buffer_mapping):
        """Build a single-layer graph for ``layer`` and time its execution.

        Runs the layer twice with a ``LayerProfiler`` attached and returns the
        time of the second run only (second-run total minus first-run total),
        excluding first-run warm-up cost.  The delta is multiplied by 1000.0 —
        presumably an ms -> us conversion, matching the "us" wording in the
        caller's log messages (TODO confirm the unit IProfiler reports).
        """
        # Plugin layers can't have attributes patched in place: rebuild the
        # plugin with the updated attributes and pass it via layer_attrs.
        is_plugin = layer.type == trt.LayerType.PLUGIN_V2
        if is_plugin and len(layer_attrs) > 0:
            plugin_info = get_plugin_info(
                get_trt_network(layer),
                layer.name,
            )
            new_plugin, _ = get_updated_plugin(plugin_info, layer_attrs)
            layer_attrs = {"plugin": new_plugin}
        graph, output_mapping = get_per_layer_graph(layer, shapes, values,
                                                    layer_attrs)
        graph._io_buffer_mapping = io_buffer_mapping
        network = graph.as_trt()
        # If every network output is a shape tensor, the layer is resolved at
        # build time, so its runtime cost is reported as zero.
        if network.num_outputs > 0 and np.all([
                network.get_output(i).is_shape_tensor
                for i in range(network.num_outputs)
        ]):
            return 0.0
        # Proxy outputs (introduced by get_per_layer_graph) take the shapes of
        # the real outputs they stand in for.
        for proxy_output, output in output_mapping.items():
            shapes[proxy_output] = shapes[output]
        if not self.timing_cache:
            self.timing_cache = network.builder.create_builder_config(
            ).create_timing_cache(b"")
        runner = graph.get_runner(
            shapes,
            values,
            timing_cache=self.timing_cache,
        )
        context = runner.session.context
        context.profiler = LayerProfiler()
        # First run warms up (allocations, lazy init); only the second run's
        # delta is reported.  Note context.profiler.time accumulates across
        # both runs, hence the subtraction.
        runner.run()
        profiler_time_first_run = context.profiler.time
        runner.run()
        return (context.profiler.time - profiler_time_first_run) * 1000.0

    def runtime_profile(self, layer, layer_attrs, input_values, strategy,
                        device_mesh):
        """Return the (cached or freshly measured) runtime of ``layer`` under ``strategy``.

        Builds a cache key from the layer, its per-device sharded input/output
        shapes, known input values, dtypes, updated attributes and the sharding
        sequences; queries ``device_mesh.prof_database`` first and only profiles
        (via :meth:`_profile`) on a miss, storing the new measurement together
        with ``strategy.alpha_beta_cost``.
        """
        logger.debug(f"start to profile layer {layer.name}")
        shapes = {}
        values = {}
        dtypes = {}
        trt_layer = layer.as_trt()
        sharding_sequences = ()
        # Gather per-device sharded shapes, dtypes, and any known constant
        # values for each input; these feed both the cache key and the
        # profiling graph.
        for i in range(layer.num_inputs):
            input = trt_layer.get_input(i)
            if input is not None:
                shapes[input.name] = strategy.sharding_specs[
                    f'input{i}'].get_sharded_shape_per_device()
                dtypes[input.name] = input.dtype
                sharding_sequences += (str(
                    strategy.sharding_specs[f"input{i}"].sharding_sequence), )
                if i in input_values:
                    values[input.name] = input_values[i]
                else:
                    value = layer.get_input(i).value
                    if value is not None:
                        values[input.name] = value
            else:
                # Missing (optional) input: keep a None placeholder so tuple
                # positions stay aligned with input indices.
                sharding_sequences += (None, )
        for i in range(layer.num_outputs):
            output = trt_layer.get_output(i)
            # If a communication action rewrites this output, profile with the
            # post-communication sharded shape instead of the raw spec's.
            if f'output{i}' in strategy.communication_actions:
                shapes[output.name] = strategy.communication_actions[
                    f'output{i}'].sharding_spec.get_sharded_shape_per_device()
            else:
                shapes[output.name] = strategy.sharding_specs[
                    f'output{i}'].get_sharded_shape_per_device()
            dtypes[output.name] = output.dtype
            sharding_sequences += (str(
                strategy.sharding_specs[f"output{i}"].sharding_sequence), )
        data_key = get_cache_key(
            trt_layer,
            shapes,
            values,
            dtypes=dtypes,
            updated_attrs=layer_attrs,
        )
        data_key += (sharding_sequences, )
        elapsed_time = device_mesh.prof_database.query(
            device_mesh.cluster_key,
            data_key,
        )
        # NOTE(review): a cached value of 0.0 is falsy and would be treated as
        # a miss and re-profiled (cheaply, since _profile short-circuits
        # shape-tensor-only layers to 0.0) — presumably acceptable; confirm
        # query() returns a falsy value (e.g. None) on a genuine miss.
        if elapsed_time:
            logger.debug(
                f'runtime profiling cache hit {data_key}: {elapsed_time} us')
            return elapsed_time
        with NvtxProfiler(f'{layer.name}_{data_key}', enable=True):
            elapsed_time = self._profile(
                layer.as_trt(),
                layer_attrs,
                shapes,
                values,
                layer.graph._io_buffer_mapping,
            )
        logger.debug(
            f'runtime profiling cache miss {data_key}: {elapsed_time} us')
        device_mesh.prof_database.update(
            device_mesh.cluster_key,
            data_key,
            (elapsed_time, strategy.alpha_beta_cost),
        )
        return elapsed_time