/*
 * SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

// NOTE(review): in this copy of the file the targets of the following nine include directives
// are missing, and the template argument lists of std::unique_ptr, std::shared_ptr,
// std::function, std::vector, std::make_shared, std::make_unique, static_cast and the
// std::chrono aliases further down are missing as well. Restore them from the canonical
// source before building.
#include
#include
#include
#include
#include
#include
#include
#include
#include

#include "NvInfer.h"
#include "NvInferPlugin.h"
#include "buffers.h"
#include "common.h"
#include "logger.h"
#include "sampleDevice.h"
#include "sampleEngines.h"
#include "sampleInference.h"
#include "sampleOptions.h"
#include "sampleReporting.h"

using namespace nvinfer1;
using namespace sample;
using namespace samplesCommon;

#if ENABLE_UNIFIED_BUILDER
using namespace nvinfer2::safe;
//! Recorder handed to the safety plugin registry; constructed with severity kINFO.
//! Declared weak, presumably so that another translation unit can supply its own definition
//! (TODO confirm).
__attribute__((weak)) std::shared_ptr gSafeRecorder = std::make_shared(nvinfer2::safe::Severity::kINFO);
#endif

namespace
{
//! Owning handle for a dynamically loaded library.
using LibraryPtr = std::unique_ptr;

//! Entry points of the TensorRT and ONNX parser libraries.
//! They start out empty and are filled in by initNvinfer() / initNvonnxparser().
std::function pCreateInferRuntimeInternal{};
std::function pCreateInferRefitterInternal{};
std::function pCreateInferBuilderInternal{};
std::function pCreateNvOnnxParserInternal{};

//! Track runtime used for the execution of trtexec.
//! Must be tracked as a global variable due to how library init functions APIs are organized.
RuntimeMode gUseRuntime = RuntimeMode::kFULL;

//! Populate the runtime/refitter/builder creation entry points.
//!
//! With dynamic linking (!TRT_STATIC) the library named by getRuntimeLibraryName(gUseRuntime)
//! is handed to initLibrary() together with a callback that resolves the symbols. With static
//! linking the *_INTERNAL functions are assigned directly.
//!
//! \return the result of initLibrary() in the dynamic case, true in the static case.
bool initNvinfer()
{
#if !TRT_STATIC
    // Function-local static: the library handle persists across calls.
    static LibraryPtr libnvinferPtr{};
    auto fetchPtrs = [](DynamicLibrary* l) {
        pCreateInferRuntimeInternal = l->symbolAddress("createInferRuntime_INTERNAL");
        try
        {
            pCreateInferRefitterInternal = l->symbolAddress("createInferRefitter_INTERNAL");
        }
        catch (const std::exception& e)
        {
            // A missing refitter entry point is tolerated: warn and leave the pointer as it was.
            sample::gLogWarning << "Could not load function createInferRefitter_INTERNAL : " << e.what() << std::endl;
        }
        // The builder entry point is only looked up when the full runtime was selected.
        if (gUseRuntime == RuntimeMode::kFULL)
        {
            pCreateInferBuilderInternal = l->symbolAddress("createInferBuilder_INTERNAL");
        }
    };
    return initLibrary(libnvinferPtr, getRuntimeLibraryName(gUseRuntime), fetchPtrs);
#else
    pCreateInferRuntimeInternal = createInferRuntime_INTERNAL;
    pCreateInferRefitterInternal = createInferRefitter_INTERNAL;
    pCreateInferBuilderInternal = createInferBuilder_INTERNAL;
    return true;
#endif // !TRT_STATIC
}

//! Populate the ONNX parser creation entry point.
//!
//! Mirrors initNvinfer(): resolved through initLibrary() on kNVONNXPARSER_LIBNAME when linking
//! dynamically, assigned directly when linking statically.
//!
//! \return the result of initLibrary() in the dynamic case, true in the static case.
bool initNvonnxparser()
{
#if !TRT_STATIC
    // Function-local static: the library handle persists across calls.
    static LibraryPtr libnvonnxparserPtr{};
    auto fetchPtrs = [](DynamicLibrary* l) {
        pCreateNvOnnxParserInternal = l->symbolAddress("createNvOnnxParser_INTERNAL");
    };
    return initLibrary(libnvonnxparserPtr, kNVONNXPARSER_LIBNAME, fetchPtrs);
#else
    pCreateNvOnnxParserInternal = createNvOnnxParser_INTERNAL;
    return true;
#endif // !TRT_STATIC
}

} // namespace

//! Create a runtime through the resolved entry point, using the sample logger and
//! NV_TENSORRT_VERSION.
//!
//! \return the created object, or a null pointer when initNvinfer() reports failure.
IRuntime* createRuntime()
{
    if (!initNvinfer())
    {
        return {};
    }
    ASSERT(pCreateInferRuntimeInternal != nullptr);
    return static_cast(pCreateInferRuntimeInternal(&gLogger.getTRTLogger(), NV_TENSORRT_VERSION));
}

//! Create a builder through the resolved entry point, using the sample logger and
//! NV_TENSORRT_VERSION.
//!
//! In the dynamic build the builder entry point is only resolved when gUseRuntime is kFULL, so
//! the ASSERT below fires for other runtime modes.
//!
//! \return the created object, or a null pointer when initNvinfer() reports failure.
IBuilder* createBuilder()
{
    if (!initNvinfer())
    {
        return {};
    }
    ASSERT(pCreateInferBuilderInternal != nullptr);
    return static_cast(pCreateInferBuilderInternal(&gLogger.getTRTLogger(), NV_TENSORRT_VERSION));
}

//! Create a refitter for \p engine through the resolved entry point.
//!
//! The refitter symbol lookup is allowed to fail in initNvinfer(); in that case the ASSERT
//! below fires.
//!
//! \param engine Engine passed by address to the refitter entry point.
//! \return the created object, or a null pointer when initNvinfer() reports failure.
IRefitter* createRefitter(ICudaEngine& engine)
{
    if (!initNvinfer())
    {
        return {};
    }
    ASSERT(pCreateInferRefitterInternal != nullptr);
    return static_cast(pCreateInferRefitterInternal(&engine, &gLogger.getTRTLogger(), NV_TENSORRT_VERSION));
}

//! Create an ONNX parser for \p network through the resolved entry point, using the sample
//! logger and NV_ONNX_PARSER_VERSION.
//!
//! \param network Network definition passed by address to the parser entry point.
//! \return the created object, or a null pointer when initNvonnxparser() reports failure.
nvonnxparser::IParser* createONNXParser(INetworkDefinition& network)
{
    if (!initNvonnxparser())
    {
        return {};
    }
    ASSERT(pCreateNvOnnxParserInternal != nullptr);
    return static_cast(
        pCreateNvOnnxParserInternal(&network, &gLogger.getTRTLogger(), NV_ONNX_PARSER_VERSION));
}

#if ENABLE_UNIFIED_BUILDER
//! Register the plugins listed in \p pluginArgs from an already loaded safety plugin library.
//!
//! \param safetyPluginRegistry Registry that receives the creators. It is dereferenced without
//!        a null check, so callers must pass a valid pointer.
//! \param libPtr Loaded library; a null pointer is reported as an error.
//! \param pluginArgs Library name (used for messages) and the namespace/name pairs to register.
//! \return false when \p libPtr is null or the "getSafetyPluginCreator" symbol is not found;
//!         true otherwise, including when individual plugins were not found and skipped.
bool processSafetyPluginLibrary(nvinfer2::safe::ISafePluginRegistry* safetyPluginRegistry, DynamicLibrary* libPtr,
    samplesSafeCommon::SafetyPluginLibraryArgument const& pluginArgs)
{
    if (libPtr == nullptr)
    {
        sample::gLogError << "Cannot open safety plugin library " << pluginArgs.libraryName << std::endl;
        return false;
    }
    std::string const pluginGetterSymbolName{"getSafetyPluginCreator"};
    auto pGetSafetyPluginCreator = libPtr->symbolAddress(pluginGetterSymbolName.c_str());
    if (pGetSafetyPluginCreator == nullptr)
    {
        sample::gLogError << "Cannot find plugin creator getter symbol from plugin library: " << pluginArgs.libraryName
                          << std::endl;
        sample::gLogError << "Please ensure interface function is correctly implemented and exported." << std::endl;
        return false;
    }
    for (auto const& pluginAttr : pluginArgs.pluginAttrs)
    {
        // Ask the library for the creator matching this namespace/name pair.
        auto pluginCreator = static_cast(
            pGetSafetyPluginCreator(pluginAttr.pluginNamespace.c_str(), pluginAttr.pluginName.c_str()));
        if (pluginCreator == nullptr)
        {
            // A missing plugin is not fatal: it is logged at info level and skipped.
            sample::gLogInfo << "Cannot find plugin " << pluginAttr.pluginNamespace << "::" << pluginAttr.pluginName
                             << " in the safety plugin library: " << pluginArgs.libraryName << std::endl;
            continue;
        }
        sample::gLogInfo << "Registering " << pluginAttr.pluginNamespace << "::" << pluginAttr.pluginName
                         << " for TensorRT safety." << std::endl;
        safetyPluginRegistry->registerCreator(*pluginCreator, pluginAttr.pluginNamespace.c_str(), *gSafeRecorder);
    }
    return true;
}
#endif

// NOTE(review): neither alias is referenced in the visible part of this file.
using time_point = std::chrono::time_point;
using duration = std::chrono::duration;

//! trtexec entry point.
//!
//! Phases, in order: parse the command line, select the CUDA device, load plugin libraries,
//! build (or load) the engine via getEngineBuildEnv(), optionally time refit, then set up and
//! run inference and print the reports.
//!
//! Most early exits return the value of gLogger.reportPass()/reportFail(); the explicit help
//! path returns EXIT_SUCCESS directly. Any std::exception escaping the body is logged and
//! reported as a failure.
int main(int argc, char** argv)
{
    std::string const sampleName = "TensorRT.trtexec";
    auto sampleTest = sample::gLogger.defineTest(sampleName, argc, argv);

    try
    {
        sample::gLogger.reportTestStart(sampleTest);

        Arguments args = argsToArgumentsMap(argc, argv);
        AllOptions options;

        // Explicit help request: print usage and leave without reporting pass or fail.
        if (parseHelp(args))
        {
            AllOptions::help(std::cout);
            return EXIT_SUCCESS;
        }

        if (!args.empty())
        {
            bool failed{false};
            try
            {
                options.parse(args);

                // Presumably parse() removes every argument it recognizes (TODO confirm);
                // whatever is still present afterwards is reported as an unknown option.
                if (!args.empty())
                {
                    AllOptions::help(std::cout);
                    for (auto const& arg : args)
                    {
                        sample::gLogError << "Unknown option: " << arg.first << " " << arg.second.first << std::endl;
                    }
                    failed = true;
                }
            }
            catch (std::invalid_argument const& arg)
            {
                AllOptions::help(std::cout);
                sample::gLogError << arg.what() << std::endl;
                failed = true;
            }

            if (failed)
            {
                return sample::gLogger.reportFail(sampleTest);
            }
        }
        else
        {
            // No arguments at all: behave as if help had been requested.
            options.helps = true;
        }

        if (options.helps)
        {
            AllOptions::help(std::cout);
            return sample::gLogger.reportPass(sampleTest);
        }

        sample::gLogInfo << options;
        if (options.reporting.verbose)
        {
            sample::setReportableSeverity(ILogger::Severity::kVERBOSE);
        }

        // Default-constructed (empty) suffix appended to the version string below.
        std::string const jitInVersion;
        setCudaDevice(options.system.device, sample::gLogInfo);
        sample::gLogInfo << std::endl;
        sample::gLogInfo << "TensorRT version: " << NV_TENSORRT_MAJOR << "." << NV_TENSORRT_MINOR << "."
                         << NV_TENSORRT_PATCH << jitInVersion << std::endl;

        // Record specified runtime
        gUseRuntime = options.build.useRuntime;

#if !TRT_STATIC
        LibraryPtr nvinferPluginLib{};
#endif /* TRT_STATIC */
        // Handles of every user-supplied plugin library loaded below.
        std::vector pluginLibs;
        if (gUseRuntime == RuntimeMode::kFULL)
        {
            sample::gLogInfo << "Loading standard plugins" << std::endl;
#if !TRT_STATIC
            nvinferPluginLib = loadLibrary(kNVINFER_PLUGIN_LIBNAME);
            auto pInitLibNvinferPlugins = nvinferPluginLib->symbolAddress("initLibNvInferPlugins");
#else  /* TRT_STATIC */
            auto pInitLibNvinferPlugins = initLibNvInferPlugins;
#endif /* TRT_STATIC */
            ASSERT(pInitLibNvinferPlugins != nullptr);
            pInitLibNvinferPlugins(&sample::gLogger.getTRTLogger(), "");
            for (auto const& pluginPath : options.system.plugins)
            {
                sample::gLogInfo << "Loading supplied plugin library: " << pluginPath << std::endl;
                pluginLibs.emplace_back(loadLibrary(pluginPath));
            }
        }
        else if (!options.system.plugins.empty())
        {
            // Caught by the handler at the bottom of main and reported as a failure.
            throw std::runtime_error("TRT-18412: Plugins require --useRuntime=full.");
        }

#if ENABLE_UNIFIED_BUILDER
        auto safetyPluginRegistry = sample::safe::getSafePluginRegistry(*gSafeRecorder);
        ASSERT(safetyPluginRegistry != nullptr);
        if (!options.system.safetyPlugins.empty())
        {
            for (auto const& safetyPluginArg : options.system.safetyPlugins)
            {
                sample::gLogInfo << "Loading supplied safety plugin library with manual registration: "
                                 << safetyPluginArg.libraryName << std::endl;
                auto pluginLib = loadLibrary(safetyPluginArg.libraryName);
                // NOTE(review): the boolean result is ignored, so a library that failed to
                // process is still stored in pluginLibs and execution continues. Confirm this
                // best-effort behavior is intended.
                processSafetyPluginLibrary(safetyPluginRegistry, pluginLib.get(), safetyPluginArg);
                pluginLibs.emplace_back(std::move(pluginLib));
            }
        }
#endif // ENABLE_UNIFIED_BUILDER

        if (options.build.safe && !sample::hasSafeRuntime())
        {
            sample::gLogError << "Safety is not supported because safety runtime library is unavailable." << std::endl;
            return sample::gLogger.reportFail(sampleTest);
        }

        if (!options.build.safe && options.build.consistency)
        {
            sample::gLogInfo << "Skipping consistency checker on non-safety mode." << std::endl;
            options.build.consistency = false;
        }

        // Start engine building phase.
        std::unique_ptr bEnv(new BuildEnvironment(options.build.safe, options.build.versionCompatible,
            options.system.DLACore, options.build.tempdir, options.build.tempfileControls, options.build.leanDLLPath,
            sampleTest.getCmdline()));

        bool buildPass = getEngineBuildEnv(options.model, options.build, options.system, *bEnv, sample::gLogError);

        if (!buildPass)
        {
            sample::gLogError << "Engine set up failed" << std::endl;
            return sample::gLogger.reportFail(sampleTest);
        }

#if ENABLE_UNIFIED_BUILDER
        safetyPluginRegistry->setSafeRecorder(*gSafeRecorder);
#endif // ENABLE_UNIFIED_BUILDER

        // Exit as version is already printed during getEngineBuildEnv
        if (options.build.getPlanVersionOnly)
        {
            return sample::gLogger.reportPass(sampleTest);
        }

        // dynamicPlugins may have been updated by getEngineBuildEnv above
        bEnv->engine.setDynamicPlugins(options.system.dynamicPlugins);

        // When some options are enabled, engine deserialization is not supported on the platform that the engine was
        // built.
        bool const supportDeserialization = !options.build.safe && !options.build.buildDLAStandalone
            && options.build.runtimePlatform == nvinfer1::RuntimePlatform::kSAME_AS_BUILD;

        if (supportDeserialization && options.build.refittable)
        {
            auto* engine = bEnv->engine.get();
            if (options.reporting.refit)
            {
                dumpRefittable(*engine);
            }
            if (options.inference.timeRefit)
            {
                // Timing a refit needs the network; without one the step is skipped with a warning.
                if (bEnv->network.operator bool())
                {
                    bool const success = timeRefit(*bEnv->network, *engine, options.inference.threads);
                    if (!success)
                    {
                        sample::gLogError << "Engine refit failed." << std::endl;
                        return sample::gLogger.reportFail(sampleTest);
                    }
                }
                else
                {
                    sample::gLogWarning << "Network not available, skipped timing refit." << std::endl;
                }
            }
        }

        if (options.build.skipInference)
        {
            // Layer and profile information is only printed when the engine can be deserialized here.
            if (supportDeserialization)
            {
                printLayerInfo(options.reporting, bEnv->engine.get(), nullptr);
                printOptimizationProfileInfo(options.reporting, bEnv->engine.get());
            }
            sample::gLogInfo << "Skipped inference phase since --skipInference is added." << std::endl;
            return sample::gLogger.reportPass(sampleTest);
        }

        // Inference environment, constructed from the build environment. The non-safe and safe
        // branches both call std::make_unique on *bEnv.
        std::unique_ptr iEnv;
        if (!options.build.safe)
        {
            iEnv = std::make_unique(*bEnv);
        }
        else
        {
#if ENABLE_UNIFIED_BUILDER
            iEnv = std::make_unique(*bEnv);
#else
            // NOTE(review): this failure path logs at info severity rather than error.
            sample::gLogInfo << "--safe flag is enabled but application is not compatible with safety." << std::endl;
            return sample::gLogger.reportFail(sampleTest);
#endif
        }

        // We avoid re-loading some dynamic plugins while deserializing
        // if they were already serialized with `setPluginsToSerialize`.
        std::vector dynamicPluginsNotSerialized;
        for (auto& pluginName : options.system.dynamicPlugins)
        {
            if (std::find(options.system.setPluginsToSerialize.begin(), options.system.setPluginsToSerialize.end(),
                    pluginName)
                == options.system.setPluginsToSerialize.end())
            {
                dynamicPluginsNotSerialized.emplace_back(pluginName);
            }
        }
        iEnv->engine.setDynamicPlugins(dynamicPluginsNotSerialized);

        // Delete build environment.
        bEnv.reset();

        if (options.inference.timeDeserialize)
        {
            // A true return value from timeDeserialize() is treated as failure here.
            if (timeDeserialize(*iEnv, options.system))
            {
                return sample::gLogger.reportFail(sampleTest);
            }
            return sample::gLogger.reportPass(sampleTest);
        }

        // Safe mode combined with a DLA core: inference is not run here. The hint below is
        // logged at info severity and the run is reported as failed.
        if (options.build.safe && options.system.DLACore >= 0)
        {
            sample::gLogInfo << "Safe DLA capability is detected. Please save DLA loadable with --saveEngine option, "
                                "then use dla_safety_runtime to run inference with saved DLA loadable, "
                                "or alternatively run with your own application"
                             << std::endl;
            return sample::gLogger.reportFail(sampleTest);
        }

        bool const profilerEnabled = options.reporting.profile || !options.reporting.exportProfile.empty();

        bool const layerInfoEnabled = options.reporting.layerInfo || !options.reporting.exportLayerInfo.empty();

        if (iEnv->safe && (profilerEnabled || layerInfoEnabled))
        {
            sample::gLogError << "Safe runtime does not support --dumpProfile or --exportProfile= or "
                                 "--dumpLayerInfo or --exportLayerInfo=, please use "
                                 "--verbose to print profiling info."
                              << std::endl;
            return sample::gLogger.reportFail(sampleTest);
        }

        // Profiling within the main run: the profiler is created before inference set-up.
        if (profilerEnabled && !options.inference.rerun)
        {
            iEnv->profiler.reset(new Profiler);
            if (options.inference.graph && (getCudaDriverVersion() < 11010 || getCudaRuntimeVersion() < 11000))
            {
                options.inference.graph = false;
                sample::gLogWarning << "Graph profiling only works with CUDA 11.1 and beyond. Ignored --useCudaGraph flag "
                                       "and disabled CUDA graph."
                                    << std::endl;
            }
        }

        if (!setUpInference(*iEnv, options.inference, options.system))
        {
            sample::gLogError << "Inference set up failed" << std::endl;
            return sample::gLogger.reportFail(sampleTest);
        }

        if (!options.build.safe)
        {
            printLayerInfo(options.reporting, iEnv->engine.get(), static_cast(iEnv.get())->contexts.front().get());
            printOptimizationProfileInfo(options.reporting, iEnv->engine.get());
        }

        std::vector trace;
        sample::gLogInfo << "Starting inference" << std::endl;

        if (!runInference(options.inference, *iEnv, options.system.device, trace, options.reporting))
        {
            sample::gLogError << "Error occurred during inference" << std::endl;
            return sample::gLogger.reportFail(sampleTest);
        }

        if (profilerEnabled && !options.inference.rerun)
        {
            sample::gLogInfo << "The e2e network timing is not reported since it is inaccurate due to the extra "
                             << "synchronizations when the profiler is enabled." << std::endl;
            sample::gLogInfo << "To show e2e network timing report, add --separateProfileRun to profile layer timing in a "
                             << "separate run or remove --dumpProfile to disable the profiler." << std::endl;
        }
        else
        {
            printPerformanceReport(trace, options.reporting, options.inference, sample::gLogInfo, sample::gLogWarning,
                sample::gLogVerbose);
        }

        printOutput(options.reporting, *iEnv, options.inference.batch);

        // Separate profiling run: a profiler is attached to the first execution context only
        // now, after the timed run above, and inference is executed a second time.
        if (profilerEnabled && options.inference.rerun)
        {
            // iEnv->profiler takes the pointer via reset(); the same raw pointer is also
            // passed to the context.
            auto* profiler = new Profiler;
            iEnv->profiler.reset(profiler);
            static_cast(iEnv.get())->contexts.front()->setProfiler(profiler);
            static_cast(iEnv.get())->contexts.front()->setEnqueueEmitsProfile(false);
            // NOTE(review): same CUDA version check and warning as in the non-rerun branch above.
            if (options.inference.graph && (getCudaDriverVersion() < 11010 || getCudaRuntimeVersion() < 11000))
            {
                options.inference.graph = false;
                sample::gLogWarning << "Graph profiling only works with CUDA 11.1 and beyond. Ignored --useCudaGraph flag "
                                       "and disabled CUDA graph."
                                    << std::endl;
            }
            if (!runInference(options.inference, *iEnv, options.system.device, trace, options.reporting))
            {
                sample::gLogError << "Error occurred during inference" << std::endl;
                return sample::gLogger.reportFail(sampleTest);
            }
        }
        printPerformanceProfile(options.reporting, *iEnv);

        return sample::gLogger.reportPass(sampleTest);
    }
    catch (std::exception const& e)
    {
        sample::gLogError << "Uncaught exception detected: " << e.what() << std::endl;
    }
    // Only reached after an exception was caught above.
    return sample::gLogger.reportFail(sampleTest);
}