#
# Copyright 2022 The HuggingFace Inc. team.
# SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import torch
from torch.cuda import nvtx
from collections import OrderedDict
import numpy as np
from polygraphy.backend.common import bytes_from_path
from polygraphy import util
from polygraphy.backend.trt import ModifyNetworkOutputs, Profile
from polygraphy.backend.trt import (
engine_from_bytes,
engine_from_network,
network_from_onnx_path,
save_engine,
)
from polygraphy.logger import G_LOGGER
import tensorrt as trt
from logging import error, warning
from tqdm import tqdm
import copy
# TensorRT logger created with ERROR severity; it is handed to `trt.Refitter`
# in `Engine.refit_from_dict`.
TRT_LOGGER = trt.Logger(trt.Logger.ERROR)
# Set the severity of Polygraphy's global logger to ERROR as well.
G_LOGGER.module_severity = G_LOGGER.ERROR
# Map of numpy dtype -> torch dtype
numpy_to_torch_dtype_dict = {
    np.uint8: torch.uint8,
    np.int8: torch.int8,
    np.int16: torch.int16,
    np.int32: torch.int32,
    np.int64: torch.int64,
    np.float16: torch.float16,
    np.float32: torch.float32,
    np.float64: torch.float64,
    np.complex64: torch.complex64,
    np.complex128: torch.complex128,
}

# The boolean key depends on the NumPy release: `np.bool_` from 1.24 on,
# the legacy `np.bool` alias before that. Compare (major, minor) as integers;
# a plain string comparison is lexicographic and would, for example, order
# "1.9.0" after "1.24.0".
_numpy_major_minor = tuple(
    int(part) for part in np.version.full_version.split(".")[:2]
)
if _numpy_major_minor >= (1, 24):
    numpy_to_torch_dtype_dict[np.bool_] = torch.bool
else:
    numpy_to_torch_dtype_dict[np.bool] = torch.bool

# Map of torch dtype -> numpy dtype
torch_to_numpy_dtype_dict = {
    torch_dtype: numpy_dtype
    for numpy_dtype, torch_dtype in numpy_to_torch_dtype_dict.items()
}
class TQDMProgressMonitor(trt.IProgressMonitor):
    """Show TensorRT build progress as nested ``tqdm`` progress bars.

    An instance is assigned to ``config.progress_monitor`` in ``Engine.build``;
    TensorRT then invokes ``phase_start`` / ``step_complete`` / ``phase_finish``.
    A ``KeyboardInterrupt`` raised inside any callback is turned into a
    cancellation request that is reported through ``step_complete``'s return
    value.
    """

    def __init__(self):
        trt.IProgressMonitor.__init__(self)
        # phase name -> {"tq": tqdm bar, "nbIndents": depth, "parent_phase": name or None}
        self._active_phases = {}
        # Value returned by step_complete(); False requests cancellation of the build.
        self._step_result = True
        # Phases nested this deep (or deeper) do not get a bar of their own.
        self.max_indent = 5

    def phase_start(self, phase_name, parent_phase, num_steps):
        """Create a progress bar for a phase that is starting.

        Args:
            phase_name: Name of the phase; used as the bar description and key.
            parent_phase: Name of the enclosing phase, or ``None`` for a root phase.
            num_steps: Total number of steps, used as the bar total.
        """
        try:
            if parent_phase is None:
                depth = 0
                # Only root-level bars stay on screen once finished.
                leave = True
            else:
                # An untracked parent counts as max depth, so its children are
                # skipped as well.
                parent = self._active_phases.get(parent_phase, {})
                depth = parent.get("nbIndents", self.max_indent) + 1
                if depth >= self.max_indent:
                    return
                leave = False
            self._active_phases[phase_name] = {
                "tq": tqdm(
                    total=num_steps, desc=phase_name, leave=leave, position=depth
                ),
                "nbIndents": depth,
                "parent_phase": parent_phase,
            }
        except KeyboardInterrupt:
            # The phase_start callback cannot directly cancel the build, so
            # request the cancellation from within step_complete.
            self._step_result = False

    def phase_finish(self, phase_name):
        """Complete and drop the bar of a finished phase.

        Phases without a bar (skipped in ``phase_start``) are ignored.
        """
        try:
            phase = self._active_phases.get(phase_name)
            if phase is None:
                return
            bar = phase["tq"]
            # Fill the bar up to its total.
            bar.update(bar.total - bar.n)
            # Redraw every ancestor bar that is still tracked; stop at the
            # first one that is not, instead of raising KeyError.
            ancestor = phase.get("parent_phase")
            while ancestor is not None and ancestor in self._active_phases:
                self._active_phases[ancestor]["tq"].refresh()
                ancestor = self._active_phases[ancestor].get("parent_phase")
            del self._active_phases[phase_name]
        except KeyboardInterrupt:
            # Cancellation is reported by the next step_complete() call.
            self._step_result = False

    def step_complete(self, phase_name, step):
        """Advance the bar of ``phase_name`` to ``step``.

        Returns:
            ``True`` to continue the build, ``False`` to request cancellation
            (after a ``KeyboardInterrupt`` in any callback).
        """
        try:
            phase = self._active_phases.get(phase_name)
            if phase is not None:
                bar = phase["tq"]
                bar.update(step - bar.n)
            return self._step_result
        except KeyboardInterrupt:
            # There is no need to propagate this exception to TensorRT. We can
            # simply cancel the build.
            return False
class Engine:
    """Build, load, refit and run a TensorRT engine stored at ``engine_path``."""

    def __init__(
        self,
        engine_path,
    ):
        """Remember the engine file path; nothing is loaded or built yet."""
        self.engine_path = engine_path
        self.engine = None
        self.context = None
        self.buffers = OrderedDict()
        self.tensors = OrderedDict()
        self.cuda_graph_instance = None  # cuda graph

    def __del__(self):
        # Drop the context before the engine it was created from. `pop` with a
        # default keeps this safe if an attribute is already gone.
        for attr in ("context", "engine", "buffers", "tensors"):
            self.__dict__.pop(attr, None)

    def reset(self, engine_path=None):
        """Release engine, context and tensors and point at a new engine file.

        The attributes are reset to their empty state (``None`` / empty dicts)
        rather than deleted, so the instance stays usable and ``__del__`` does
        not fail afterwards.
        """
        self.context = None
        self.engine = None
        self.engine_path = engine_path
        self.buffers = OrderedDict()
        self.tensors = OrderedDict()
        self.inputs = {}
        self.outputs = {}

    def refit_from_dict(self, refit_weights, is_fp16):
        """Refit the loaded engine with the tensors in ``refit_weights``.

        Args:
            refit_weights: Dict mapping TensorRT weight names to torch tensors.
                Entries that are used are replaced in place by their CPU
                (and, if ``is_fp16``, half-precision) version.
            is_fp16: Convert weights to half precision and pass them to
                TensorRT as ``HALF``; otherwise they are passed as ``FLOAT``.

        Raises:
            AssertionError: If some key of ``refit_weights`` is not a
                refittable weight of the engine.
            SystemExit: With status 1 if TensorRT fails to refit the engine.
        """
        # Initialize refitter
        refitter = trt.Refitter(self.engine, TRT_LOGGER)
        trt_datatype = trt.DataType.HALF if is_fp16 else trt.DataType.FLOAT
        refitted_weights = set()

        # iterate through all tensorrt refittable weights
        for trt_weight_name in refitter.get_all_weights():
            if trt_weight_name not in refit_weights:
                continue

            weight = refit_weights[trt_weight_name]
            if is_fp16:
                weight = weight.half()
            # trt.Weights only receives a raw pointer and an element count, so
            # the data must be dense in host memory.
            weight = weight.cpu().contiguous()
            # Store the converted tensor back so it stays referenced until the
            # refit below has run.
            refit_weights[trt_weight_name] = weight

            trt_wt_tensor = trt.Weights(
                trt_datatype,
                weight.data_ptr(),
                torch.numel(weight),
            )
            # apply refit
            refitter.set_named_weights(trt_weight_name, trt_wt_tensor)
            refitted_weights.add(trt_weight_name)

        assert refitted_weights == set(refit_weights.keys()), (
            "refit_weights contains names that are not refittable engine weights"
        )

        if not refitter.refit_cuda_engine():
            print("Error: failed to refit new weights.")
            # Signal the failure with a non-zero exit status.
            raise SystemExit(1)

    def build(
        self,
        onnx_path,
        fp16,
        input_profile=None,
        enable_refit=False,
        enable_preview=False,
        enable_all_tactics=False,
        timing_cache=None,
        update_output_names=None,
    ):
        """Build a TensorRT engine from an ONNX file and save it to ``engine_path``.

        Args:
            onnx_path: Path of the ONNX model to parse.
            fp16: Set the FP16 builder flag when truthy.
            input_profile: Optional sequence of dicts, each mapping an input
                name to a ``(min, opt, max)`` triple of shapes. One
                optimization profile is created per dict.
            enable_refit: Set the REFIT builder flag when truthy.
            enable_preview: Accepted for compatibility; not used here.
            enable_all_tactics: Accepted for compatibility; currently has no
                effect on the builder configuration.
            timing_cache: Optional path of a tactic timing cache. It is loaded
                if the file exists and passed as ``save_timing_cache`` to the
                engine build.
            update_output_names: Optional output names passed to
                ``ModifyNetworkOutputs``.

        Returns:
            ``0`` on success, ``1`` if building or saving the engine failed.
        """
        print(f"Building TensorRT engine for {onnx_path}: {self.engine_path}")

        # One Polygraphy profile per entry of input_profile, or a single
        # default profile when none is given.
        if input_profile:
            profiles = []
            for shapes in input_profile:
                profile = Profile()
                for name, dims in shapes.items():
                    assert len(dims) == 3
                    profile.add(name, min=dims[0], opt=dims[1], max=dims[2])
                profiles.append(profile)
        else:
            profiles = [Profile()]

        network = network_from_onnx_path(
            onnx_path, flags=[trt.OnnxParserFlag.NATIVE_INSTANCENORM]
        )
        if update_output_names:
            print(f"Updating network outputs to {update_output_names}")
            # ModifyNetworkOutputs is a lazy loader: call it so that `network`
            # is still the (builder, network, parser) tuple indexed below.
            network = ModifyNetworkOutputs(network, update_output_names)()

        builder = network[0]
        config = builder.create_builder_config()
        config.progress_monitor = TQDMProgressMonitor()
        if fp16:
            config.set_flag(trt.BuilderFlag.FP16)
        if enable_refit:
            config.set_flag(trt.BuilderFlag.REFIT)

        # Only touch the timing cache when a path was actually supplied.
        if timing_cache:
            cache = None
            try:
                with util.LockFile(timing_cache):
                    timing_cache_data = util.load_file(
                        timing_cache, description="tactic timing cache"
                    )
                    cache = config.create_timing_cache(timing_cache_data)
            except FileNotFoundError:
                warning(
                    "Timing cache file {} not found, falling back to empty timing cache.".format(
                        timing_cache
                    )
                )
            if cache is None:
                # Do what the warning says: start from an empty cache.
                cache = config.create_timing_cache(b"")
            config.set_timing_cache(cache, ignore_mismatch=True)

        for profile in profiles:
            trt_profile = profile.fill_defaults(network[1]).to_trt(
                builder, network[1]
            )
            config.add_optimization_profile(trt_profile)

        try:
            engine = engine_from_network(
                network,
                config,
                save_timing_cache=timing_cache,
            )
        except Exception as e:
            error(f"Failed to build engine: {e}")
            return 1
        try:
            save_engine(engine, path=self.engine_path)
        except Exception as e:
            error(f"Failed to save engine: {e}")
            return 1
        return 0

    def load(self):
        """Deserialize the engine stored at ``engine_path`` into ``self.engine``."""
        print(f"Loading TensorRT engine: {self.engine_path}")
        self.engine = engine_from_bytes(bytes_from_path(self.engine_path))

    def activate(self, reuse_device_memory=None):
        """Create the execution context for the loaded engine.

        Args:
            reuse_device_memory: When truthy, the context is created without
                its own device memory.
        """
        if reuse_device_memory:
            self.context = self.engine.create_execution_context_without_device_memory()
            # NOTE(review): `reuse_device_memory` is not assigned to the
            # context here (the assignment was disabled in the original code);
            # presumably the caller sets `context.device_memory` -- confirm.
        else:
            self.context = self.engine.create_execution_context()

    def allocate_buffers(self, shape_dict=None, device="cuda"):
        """Allocate one torch tensor per engine I/O tensor into ``self.tensors``.

        Args:
            shape_dict: Optional dict mapping a tensor name to an object with
                a ``.shape`` attribute; overrides the shape reported by the
                context for that tensor.
            device: Torch device on which the tensors are allocated.
        """
        # NOTE(review): this method uses the binding-based TensorRT API while
        # infer() uses the name-based tensor API; confirm that the targeted
        # TensorRT version still provides the binding calls.
        nvtx.range_push("allocate_buffers")
        try:
            for idx in range(self.engine.num_io_tensors):
                binding = self.engine[idx]
                if shape_dict and binding in shape_dict:
                    shape = shape_dict[binding].shape
                else:
                    shape = self.context.get_binding_shape(idx)
                dtype = trt.nptype(self.engine.get_binding_dtype(binding))
                if self.engine.binding_is_input(binding):
                    self.context.set_binding_shape(idx, shape)
                # Allocate directly on the target device.
                self.tensors[binding] = torch.empty(
                    tuple(shape),
                    dtype=numpy_to_torch_dtype_dict[dtype],
                    device=device,
                )
        finally:
            # Keep the NVTX range balanced even if allocation fails.
            nvtx.range_pop()

    def infer(self, feed_dict, stream, use_cuda_graph=False):
        """Copy the inputs into the bound tensors and run the engine.

        Args:
            feed_dict: Dict mapping tensor names to values copied into the
                matching entry of ``self.tensors``.
            stream: Stream handle passed to ``execute_async_v3``.
            use_cuda_graph: Accepted for compatibility; not used here.

        Returns:
            ``self.tensors``, the dict holding all input and output tensors.

        Raises:
            ValueError: If ``execute_async_v3`` reports failure.
        """
        nvtx.range_push("set_tensors")
        try:
            for name, buf in feed_dict.items():
                self.tensors[name].copy_(buf)
            for name, tensor in self.tensors.items():
                self.context.set_tensor_address(name, tensor.data_ptr())
        finally:
            nvtx.range_pop()

        nvtx.range_push("execute")
        try:
            noerror = self.context.execute_async_v3(stream)
            if not noerror:
                raise ValueError("ERROR: inference failed.")
        finally:
            # Close the range on the failure path too.
            nvtx.range_pop()
        return self.tensors

    def __str__(self):
        """Return one ``name = shape`` line per binding and optimization profile."""
        lines = []
        for opt_profile in range(self.engine.num_optimization_profiles):
            for binding_idx in range(self.engine.num_bindings):
                name = self.engine.get_binding_name(binding_idx)
                shape = self.engine.get_profile_shape(opt_profile, name)
                lines.append(f"\t{name} = {shape}\n")
        return "".join(lines)
|