| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | #include <algorithm> |
| | #include <cctype> |
| | #include <chrono> |
| | #include <cmath> |
| | #include <functional> |
| | #include <iostream> |
| | #include <memory> |
| | #include <sys/stat.h> |
| | #include <vector> |
| |
|
| | #include "NvInfer.h" |
| | #include "NvInferPlugin.h" |
| |
|
| | #include "buffers.h" |
| | #include "common.h" |
| | #include "logger.h" |
| | #include "sampleDevice.h" |
| | #include "sampleEngines.h" |
| | #include "sampleInference.h" |
| | #include "sampleOptions.h" |
| | #include "sampleReporting.h" |
| |
|
| | using namespace nvinfer1; |
| | using namespace sample; |
| | using namespace samplesCommon; |
| |
|
| | #if ENABLE_UNIFIED_BUILDER |
| | using namespace nvinfer2::safe; |
| | __attribute__((weak)) std::shared_ptr<sample::SampleSafeRecorder> gSafeRecorder |
| | = std::make_shared<sample::SampleSafeRecorder>(nvinfer2::safe::Severity::kINFO); |
| | #endif |
| |
|
| | namespace |
| | { |
| | using LibraryPtr = std::unique_ptr<DynamicLibrary>; |
| |
|
| | std::function<void*(void*, int32_t)> pCreateInferRuntimeInternal{}; |
| | std::function<void*(void*, void*, int32_t)> pCreateInferRefitterInternal{}; |
| | std::function<void*(void*, int32_t)> pCreateInferBuilderInternal{}; |
| | std::function<void*(void*, void*, int)> pCreateNvOnnxParserInternal{}; |
| |
|
| | |
| | |
| | RuntimeMode gUseRuntime = RuntimeMode::kFULL; |
| |
|
| | bool initNvinfer() |
| | { |
| | #if !TRT_STATIC |
| | static LibraryPtr libnvinferPtr{}; |
| | auto fetchPtrs = [](DynamicLibrary* l) { |
| | pCreateInferRuntimeInternal = l->symbolAddress<void*(void*, int32_t)>("createInferRuntime_INTERNAL"); |
| | try |
| | { |
| | pCreateInferRefitterInternal |
| | = l->symbolAddress<void*(void*, void*, int32_t)>("createInferRefitter_INTERNAL"); |
| | } |
| | catch (const std::exception& e) |
| | { |
| | sample::gLogWarning << "Could not load function createInferRefitter_INTERNAL : " << e.what() << std::endl; |
| | } |
| |
|
| | if (gUseRuntime == RuntimeMode::kFULL) |
| | { |
| | pCreateInferBuilderInternal = l->symbolAddress<void*(void*, int32_t)>("createInferBuilder_INTERNAL"); |
| | } |
| | }; |
| | return initLibrary(libnvinferPtr, getRuntimeLibraryName(gUseRuntime), fetchPtrs); |
| | #else |
| | pCreateInferRuntimeInternal = createInferRuntime_INTERNAL; |
| | pCreateInferRefitterInternal = createInferRefitter_INTERNAL; |
| | pCreateInferBuilderInternal = createInferBuilder_INTERNAL; |
| | return true; |
| | #endif |
| | } |
| |
|
| | bool initNvonnxparser() |
| | { |
| | #if !TRT_STATIC |
| | static LibraryPtr libnvonnxparserPtr{}; |
| | auto fetchPtrs = [](DynamicLibrary* l) { |
| | pCreateNvOnnxParserInternal = l->symbolAddress<void*(void*, void*, int)>("createNvOnnxParser_INTERNAL"); |
| | }; |
| | return initLibrary(libnvonnxparserPtr, kNVONNXPARSER_LIBNAME, fetchPtrs); |
| | #else |
| | pCreateNvOnnxParserInternal = createNvOnnxParser_INTERNAL; |
| | return true; |
| | #endif |
| | } |
| |
|
| | } |
| |
|
| | IRuntime* createRuntime() |
| | { |
| | if (!initNvinfer()) |
| | { |
| | return {}; |
| | } |
| | ASSERT(pCreateInferRuntimeInternal != nullptr); |
| | return static_cast<IRuntime*>(pCreateInferRuntimeInternal(&gLogger.getTRTLogger(), NV_TENSORRT_VERSION)); |
| | } |
| |
|
| | IBuilder* createBuilder() |
| | { |
| | if (!initNvinfer()) |
| | { |
| | return {}; |
| | } |
| | ASSERT(pCreateInferBuilderInternal != nullptr); |
| | return static_cast<IBuilder*>(pCreateInferBuilderInternal(&gLogger.getTRTLogger(), NV_TENSORRT_VERSION)); |
| | } |
| |
|
| | IRefitter* createRefitter(ICudaEngine& engine) |
| | { |
| | if (!initNvinfer()) |
| | { |
| | return {}; |
| | } |
| | ASSERT(pCreateInferRefitterInternal != nullptr); |
| | return static_cast<IRefitter*>(pCreateInferRefitterInternal(&engine, &gLogger.getTRTLogger(), NV_TENSORRT_VERSION)); |
| | } |
| |
|
| | nvonnxparser::IParser* createONNXParser(INetworkDefinition& network) |
| | { |
| | if (!initNvonnxparser()) |
| | { |
| | return {}; |
| | } |
| | ASSERT(pCreateNvOnnxParserInternal != nullptr); |
| | return static_cast<nvonnxparser::IParser*>( |
| | pCreateNvOnnxParserInternal(&network, &gLogger.getTRTLogger(), NV_ONNX_PARSER_VERSION)); |
| | } |
| |
|
| | #if ENABLE_UNIFIED_BUILDER |
| |
|
| | bool processSafetyPluginLibrary(nvinfer2::safe::ISafePluginRegistry* safetyPluginRegistry, DynamicLibrary* libPtr, |
| | samplesSafeCommon::SafetyPluginLibraryArgument const& pluginArgs) |
| | { |
| | if (libPtr == nullptr) |
| | { |
| | sample::gLogError << "Cannot open safety plugin library " << pluginArgs.libraryName << std::endl; |
| | return false; |
| | } |
| | std::string const pluginGetterSymbolName{"getSafetyPluginCreator"}; |
| | auto pGetSafetyPluginCreator |
| | = libPtr->symbolAddress<void*(char const*, char const*)>(pluginGetterSymbolName.c_str()); |
| | if (pGetSafetyPluginCreator == nullptr) |
| | { |
| | sample::gLogError << "Cannot find plugin creator getter symbol from plugin library: " << pluginArgs.libraryName |
| | << std::endl; |
| | sample::gLogError << "Please ensure interface function is correctly implemented and exported." << std::endl; |
| | return false; |
| | } |
| |
|
| | for (auto const& pluginAttr : pluginArgs.pluginAttrs) |
| | { |
| | auto pluginCreator = static_cast<IPluginCreatorInterface*>( |
| | pGetSafetyPluginCreator(pluginAttr.pluginNamespace.c_str(), pluginAttr.pluginName.c_str())); |
| | if (pluginCreator == nullptr) |
| | { |
| | sample::gLogInfo << "Cannot find plugin " << pluginAttr.pluginNamespace << "::" << pluginAttr.pluginName |
| | << " in the safety plugin library: " << pluginArgs.libraryName << std::endl; |
| | continue; |
| | } |
| | sample::gLogInfo << "Registering " << pluginAttr.pluginNamespace << "::" << pluginAttr.pluginName |
| | << " for TensorRT safety." << std::endl; |
| | safetyPluginRegistry->registerCreator(*pluginCreator, pluginAttr.pluginNamespace.c_str(), *gSafeRecorder); |
| | } |
| | return true; |
| | } |
| | #endif |
| |
|
| | using time_point = std::chrono::time_point<std::chrono::high_resolution_clock>; |
| | using duration = std::chrono::duration<float>; |
| |
|
//! trtexec entry point: parse command-line options, optionally build an
//! engine, then set up and time inference, reporting pass/fail through the
//! sample logger test harness.
int main(int argc, char** argv)
{
    std::string const sampleName = "TensorRT.trtexec";

    auto sampleTest = sample::gLogger.defineTest(sampleName, argc, argv);

    try
    {
        sample::gLogger.reportTestStart(sampleTest);

        // ---- Option parsing -------------------------------------------------
        Arguments args = argsToArgumentsMap(argc, argv);
        AllOptions options;

        if (parseHelp(args))
        {
            AllOptions::help(std::cout);
            return EXIT_SUCCESS;
        }

        if (!args.empty())
        {
            bool failed{false};
            try
            {
                options.parse(args);

                // Anything left in `args` after parse() was not recognized.
                if (!args.empty())
                {
                    AllOptions::help(std::cout);
                    for (auto const& arg : args)
                    {
                        sample::gLogError << "Unknown option: " << arg.first << " " << arg.second.first << std::endl;
                    }
                    failed = true;
                }
            }
            catch (std::invalid_argument const& arg)
            {
                AllOptions::help(std::cout);
                sample::gLogError << arg.what() << std::endl;
                failed = true;
            }

            if (failed)
            {
                return sample::gLogger.reportFail(sampleTest);
            }
        }
        else
        {
            // No arguments at all: show help.
            options.helps = true;
        }

        if (options.helps)
        {
            AllOptions::help(std::cout);
            return sample::gLogger.reportPass(sampleTest);
        }

        // ---- Environment / version banner ----------------------------------
        sample::gLogInfo << options;
        if (options.reporting.verbose)
        {
            sample::setReportableSeverity(ILogger::Severity::kVERBOSE);
        }
        // Empty here; appended to the version banner below.
        std::string const jitInVersion;
        setCudaDevice(options.system.device, sample::gLogInfo);
        sample::gLogInfo << std::endl;
        sample::gLogInfo << "TensorRT version: " << NV_TENSORRT_MAJOR << "." << NV_TENSORRT_MINOR << "."
                         << NV_TENSORRT_PATCH << jitInVersion << std::endl;

        // ---- Plugin loading -------------------------------------------------
        gUseRuntime = options.build.useRuntime;
#if !TRT_STATIC
        LibraryPtr nvinferPluginLib{};
#endif
        // Keep plugin libraries alive for the lifetime of main().
        std::vector<LibraryPtr> pluginLibs;
        if (gUseRuntime == RuntimeMode::kFULL)
        {
            sample::gLogInfo << "Loading standard plugins" << std::endl;
#if !TRT_STATIC
            nvinferPluginLib = loadLibrary(kNVINFER_PLUGIN_LIBNAME);
            auto pInitLibNvinferPlugins
                = nvinferPluginLib->symbolAddress<bool(void*, char const*)>("initLibNvInferPlugins");
#else
            auto pInitLibNvinferPlugins = initLibNvInferPlugins;
#endif
            ASSERT(pInitLibNvinferPlugins != nullptr);
            pInitLibNvinferPlugins(&sample::gLogger.getTRTLogger(), "");
            // User-supplied plugin libraries (--plugins).
            for (auto const& pluginPath : options.system.plugins)
            {
                sample::gLogInfo << "Loading supplied plugin library: " << pluginPath << std::endl;
                pluginLibs.emplace_back(loadLibrary(pluginPath));
            }
        }
        else if (!options.system.plugins.empty())
        {
            // Plugins are only supported with the full runtime.
            throw std::runtime_error("TRT-18412: Plugins require --useRuntime=full.");
        }
#if ENABLE_UNIFIED_BUILDER
        auto safetyPluginRegistry = sample::safe::getSafePluginRegistry(*gSafeRecorder);
        ASSERT(safetyPluginRegistry != nullptr);

        // Manually register safety plugins from user-supplied libraries.
        if (!options.system.safetyPlugins.empty())
        {
            for (auto const& safetyPluginArg : options.system.safetyPlugins)
            {
                sample::gLogInfo << "Loading supplied safety plugin library with manual registration: "
                                 << safetyPluginArg.libraryName << std::endl;
                auto pluginLib = loadLibrary(safetyPluginArg.libraryName);
                processSafetyPluginLibrary(safetyPluginRegistry, pluginLib.get(), safetyPluginArg);
                pluginLibs.emplace_back(std::move(pluginLib));
            }
        }
#endif
        if (options.build.safe && !sample::hasSafeRuntime())
        {
            sample::gLogError << "Safety is not supported because safety runtime library is unavailable." << std::endl;
            return sample::gLogger.reportFail(sampleTest);
        }

        if (!options.build.safe && options.build.consistency)
        {
            // Consistency checking only applies to safety mode; silently drop it.
            sample::gLogInfo << "Skipping consistency checker on non-safety mode." << std::endl;
            options.build.consistency = false;
        }

        // ---- Engine build ---------------------------------------------------
        std::unique_ptr<BuildEnvironment> bEnv(new BuildEnvironment(options.build.safe, options.build.versionCompatible,
            options.system.DLACore, options.build.tempdir, options.build.tempfileControls, options.build.leanDLLPath,
            sampleTest.getCmdline()));

        bool buildPass = getEngineBuildEnv(options.model, options.build, options.system, *bEnv, sample::gLogError);

        if (!buildPass)
        {
            sample::gLogError << "Engine set up failed" << std::endl;
            return sample::gLogger.reportFail(sampleTest);
        }

#if ENABLE_UNIFIED_BUILDER
        safetyPluginRegistry->setSafeRecorder(*gSafeRecorder);
#endif

        // --getPlanVersionOnly: nothing more to do once the plan was inspected.
        if (options.build.getPlanVersionOnly)
        {
            return sample::gLogger.reportPass(sampleTest);
        }

        bEnv->engine.setDynamicPlugins(options.system.dynamicPlugins);

        // Deserialization-dependent features (refit dump, layer info) are only
        // available for non-safety, non-DLA-standalone, same-platform engines.
        bool const supportDeserialization = !options.build.safe && !options.build.buildDLAStandalone
            && options.build.runtimePlatform == nvinfer1::RuntimePlatform::kSAME_AS_BUILD;

        // ---- Optional refit reporting / timing ------------------------------
        if (supportDeserialization && options.build.refittable)
        {
            auto* engine = bEnv->engine.get();
            if (options.reporting.refit)
            {
                dumpRefittable(*engine);
            }
            if (options.inference.timeRefit)
            {
                // Timing refit requires the network definition to still be around.
                if (bEnv->network.operator bool())
                {
                    bool const success = timeRefit(*bEnv->network, *engine, options.inference.threads);
                    if (!success)
                    {
                        sample::gLogError << "Engine refit failed." << std::endl;
                        return sample::gLogger.reportFail(sampleTest);
                    }
                }
                else
                {
                    sample::gLogWarning << "Network not available, skipped timing refit." << std::endl;
                }
            }
        }

        // --skipInference: report engine details (when possible) and stop.
        if (options.build.skipInference)
        {
            if (supportDeserialization)
            {
                printLayerInfo(options.reporting, bEnv->engine.get(), nullptr);
                printOptimizationProfileInfo(options.reporting, bEnv->engine.get());
            }
            sample::gLogInfo << "Skipped inference phase since --skipInference is added." << std::endl;
            return sample::gLogger.reportPass(sampleTest);
        }

        // ---- Inference environment ------------------------------------------
        std::unique_ptr<InferenceEnvironmentBase> iEnv;

        if (!options.build.safe)
        {
            iEnv = std::make_unique<InferenceEnvironmentStd>(*bEnv);
        }
        else
        {
#if ENABLE_UNIFIED_BUILDER
            iEnv = std::make_unique<InferenceEnvironmentSafe>(*bEnv);
#else
            sample::gLogInfo << "--safe flag is enabled but application is not compatible with safety." << std::endl;
            return sample::gLogger.reportFail(sampleTest);
#endif
        }

        // Dynamic plugins that were NOT serialized into the plan must still be
        // supplied to the inference-side engine wrapper.
        std::vector<std::string> dynamicPluginsNotSerialized;
        for (auto& pluginName : options.system.dynamicPlugins)
        {
            if (std::find(options.system.setPluginsToSerialize.begin(), options.system.setPluginsToSerialize.end(),
                    pluginName)
                == options.system.setPluginsToSerialize.end())
            {
                dynamicPluginsNotSerialized.emplace_back(pluginName);
            }
        }

        iEnv->engine.setDynamicPlugins(dynamicPluginsNotSerialized);
        // Build environment no longer needed; free it before timing inference.
        bEnv.reset();

        // --timeDeserialize: timeDeserialize() returns true on failure.
        if (options.inference.timeDeserialize)
        {
            if (timeDeserialize(*iEnv, options.system))
            {
                return sample::gLogger.reportFail(sampleTest);
            }
            return sample::gLogger.reportPass(sampleTest);
        }
        // Safe DLA loadables cannot be run here; direct the user elsewhere.
        if (options.build.safe && options.system.DLACore >= 0)
        {
            sample::gLogInfo << "Safe DLA capability is detected. Please save DLA loadable with --saveEngine option, "
                                "then use dla_safety_runtime to run inference with saved DLA loadable, "
                                "or alternatively run with your own application"
                             << std::endl;
            return sample::gLogger.reportFail(sampleTest);
        }
        bool const profilerEnabled = options.reporting.profile || !options.reporting.exportProfile.empty();

        bool const layerInfoEnabled = options.reporting.layerInfo || !options.reporting.exportLayerInfo.empty();
        if (iEnv->safe && (profilerEnabled || layerInfoEnabled))
        {
            sample::gLogError << "Safe runtime does not support --dumpProfile or --exportProfile=<file> or "
                                 "--dumpLayerInfo or --exportLayerInfo=<file>, please use "
                                 "--verbose to print profiling info."
                              << std::endl;
            return sample::gLogger.reportFail(sampleTest);
        }
        // Profiling during the main run (no --separateProfileRun/rerun).
        if (profilerEnabled && !options.inference.rerun)
        {
            iEnv->profiler.reset(new Profiler);
            // CUDA graph capture with profiling needs CUDA >= 11.1 driver / 11.0 runtime.
            if (options.inference.graph && (getCudaDriverVersion() < 11010 || getCudaRuntimeVersion() < 11000))
            {
                options.inference.graph = false;
                sample::gLogWarning
                    << "Graph profiling only works with CUDA 11.1 and beyond. Ignored --useCudaGraph flag "
                       "and disabled CUDA graph."
                    << std::endl;
            }
        }

        if (!setUpInference(*iEnv, options.inference, options.system))
        {
            sample::gLogError << "Inference set up failed" << std::endl;
            return sample::gLogger.reportFail(sampleTest);
        }

        if (!options.build.safe)
        {
            printLayerInfo(options.reporting, iEnv->engine.get(),
                static_cast<InferenceEnvironmentStd*>(iEnv.get())->contexts.front().get());
            printOptimizationProfileInfo(options.reporting, iEnv->engine.get());
        }
        std::vector<InferenceTrace> trace;
        sample::gLogInfo << "Starting inference" << std::endl;

        // ---- Timed inference run --------------------------------------------
        if (!runInference(options.inference, *iEnv, options.system.device, trace, options.reporting))
        {
            sample::gLogError << "Error occurred during inference" << std::endl;
            return sample::gLogger.reportFail(sampleTest);
        }

        if (profilerEnabled && !options.inference.rerun)
        {
            // Profiler synchronizations skew e2e timing, so suppress that report.
            sample::gLogInfo << "The e2e network timing is not reported since it is inaccurate due to the extra "
                             << "synchronizations when the profiler is enabled." << std::endl;
            sample::gLogInfo
                << "To show e2e network timing report, add --separateProfileRun to profile layer timing in a "
                << "separate run or remove --dumpProfile to disable the profiler." << std::endl;
        }
        else
        {
            printPerformanceReport(trace, options.reporting, options.inference, sample::gLogInfo, sample::gLogWarning,
                sample::gLogVerbose);
        }

        printOutput(options.reporting, *iEnv, options.inference.batch);

        // ---- Separate profiling run (--separateProfileRun) -------------------
        if (profilerEnabled && options.inference.rerun)
        {
            // Ownership passes to iEnv->profiler; raw pointer kept for context wiring.
            auto* profiler = new Profiler;
            iEnv->profiler.reset(profiler);
            static_cast<InferenceEnvironmentStd*>(iEnv.get())->contexts.front()->setProfiler(profiler);
            static_cast<InferenceEnvironmentStd*>(iEnv.get())->contexts.front()->setEnqueueEmitsProfile(false);
            if (options.inference.graph && (getCudaDriverVersion() < 11010 || getCudaRuntimeVersion() < 11000))
            {
                options.inference.graph = false;
                sample::gLogWarning
                    << "Graph profiling only works with CUDA 11.1 and beyond. Ignored --useCudaGraph flag "
                       "and disabled CUDA graph."
                    << std::endl;
            }
            if (!runInference(options.inference, *iEnv, options.system.device, trace, options.reporting))
            {
                sample::gLogError << "Error occurred during inference" << std::endl;
                return sample::gLogger.reportFail(sampleTest);
            }
        }
        printPerformanceProfile(options.reporting, *iEnv);

        return sample::gLogger.reportPass(sampleTest);
    }
    catch (std::exception const& e)
    {
        sample::gLogError << "Uncaught exception detected: " << e.what() << std::endl;
    }
    // Reached only when an exception escaped the try block above.
    return sample::gLogger.reportFail(sampleTest);
}
| |
|