// Copyright (c) Meta Platforms, Inc. and affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // #include "model_container.h" #include "device_functions-generated.h" #include "raii_wrapper.h" namespace { std::string GetEnumString(AITemplateDtype dtype) { switch (dtype) { case AITemplateDtype::kUnset: return "kUnset"; case AITemplateDtype::kHalf: return "kHalf"; case AITemplateDtype::kFloat: return "kFloat"; case AITemplateDtype::kInt: return "kInt"; case AITemplateDtype::kLong: return "kLong"; default: return "unknown"; } } } // namespace namespace ait { ModelContainer::ModelContainer( size_t num_models, size_t num_inputs, size_t num_outputs, size_t num_bound_constants, size_t num_unbound_constants, size_t params_size, AITemplateAllocator& allocator) : ModelContainerBase( num_inputs, num_outputs, num_bound_constants, num_unbound_constants, params_size, allocator), allocator_(allocator), num_inputs_(num_inputs), num_outputs_(num_outputs) { if (num_models == 0) { throw std::runtime_error("Number of models must be positive"); } dmlc::InitLogging("aitemplate"); // TODO(xxx): render network name int runtime_version; int driver_version; DEVICE_CHECK(GetDriverVersion(&driver_version)); DEVICE_CHECK(GetRuntimeVersion(&runtime_version)); LOG(INFO) << "Device Runtime Version: " << runtime_version << "; Driver Version: " << driver_version; int dev_id; DevicePropertyType prop; DEVICE_CHECK(GetDevice(&dev_id)); DEVICE_CHECK(GetDeviceProperties(&prop, dev_id)); bool useDebugLogging = false; if (auto var = std::getenv("LOGLEVEL")) { if (var[0] == 'd' || var[0] == 'D') { useDebugLogging = true; } } LOG(INFO) << (useDebugLogging ? PrintDebugDeviceProperties(prop) : PrintInfoDeviceProperties(prop)); LOG(INFO) << "Init AITemplate Runtime with " << num_models << " concurrency"; models_.reserve(num_models); available_models_.reserve(num_models); auto* constants_ptr = static_cast(constants_primary_.get()); for (size_t i = 0; i < num_models; ++i) { models_.push_back(Model::Create(allocator, constants_ptr)); available_models_.push_back(models_.back().get()); } constant_folder_ = ConstantFolder::Create(allocator, constants_ptr); // Wire up the constant folder's outputs to our constant buffer. size_t constant_idx = 0; for (auto offset : constant_folding_outputs_offsets_) { constant_folder_->SetOutput(constants_ptr + offset, constant_idx++); } } void ModelContainer::Run( const AITData* inputs, size_t num_inputs, AITData* outputs, size_t num_outputs, StreamType stream, bool sync, bool graph_mode, int64_t** output_shapes_out) { std::shared_lock constants_lk(constants_sync_mutex_); if (!constant_folded_once_) { // We don't require users to manually call FoldConstants the first time. // Note that if this throws (due to an unset constant, for example) // constant_folded_once_ will not be set. constants_lk.unlock(); std::unique_lock constants_unique_lk(constants_sync_mutex_); // Check again, another thread may have updated after we unlocked. if (!constant_folded_once_) { FoldConstantsImpl(stream); } constants_unique_lk.unlock(); constants_lk.lock(); } auto* model = GetAvailableModel(); try { PrepareForRun(model, inputs, num_inputs, outputs, num_outputs); model->Run(stream, graph_mode); } catch (...) { std::lock_guard lk(models_mutex_); available_models_.push_back(model); throw; } if (output_shapes_out) { for (size_t i = 0; i < num_outputs; ++i) { auto* out_shape = output_shapes_out[i]; model->GetOutputShape(i, out_shape); } } { std::lock_guard lk(models_mutex_); pending_models_.push_back(model); } pending_models_available_.notify_one(); if (sync) { StreamSynchronize(stream); } } void ModelContainer::Profile( const AITData* inputs, size_t num_inputs, AITData* outputs, size_t num_outputs, StreamType stream, size_t num_iters, const char* filename) { auto* model = GetAvailableModel(); if (filename == nullptr) { throw; } try { PrepareForRun(model, inputs, num_inputs, outputs, num_outputs); model->Profile(stream, num_iters, filename); } catch (...) { std::lock_guard lk(models_mutex_); available_models_.push_back(model); throw; } { std::lock_guard lk(models_mutex_); pending_models_.push_back(model); } pending_models_available_.notify_one(); } void ModelContainer::RunWithOutputsOnHost( const AITData* inputs, size_t num_inputs, AITData* outputs, size_t num_outputs, StreamType stream, bool graph_mode, int64_t** output_shapes_out) { std::vector> owned_outputs_ptrs; std::vector owned_outputs; owned_outputs_ptrs.reserve(num_outputs); owned_outputs.reserve(num_outputs); for (size_t i = 0; i < num_outputs; ++i) { size_t num_bytes = MaxOutputStorageBytes(i); owned_outputs_ptrs.emplace_back( RAII_DeviceMalloc(num_bytes, allocator_), num_bytes); owned_outputs.emplace_back( owned_outputs_ptrs.back().first.get(), outputs[i].shape, outputs[i].dtype); } Run(inputs, num_inputs, owned_outputs.data(), num_outputs, stream, /*sync=*/false, graph_mode, output_shapes_out); for (size_t i = 0; i < num_outputs; ++i) { auto& owned_output = owned_outputs_ptrs[i]; auto& ptr = owned_output.first; auto num_bytes = owned_output.second; DEVICE_CHECK(CopyToHost(outputs[i].ptr, ptr.get(), num_bytes, stream)); } DEVICE_CHECK(StreamSynchronize(stream)); } float ModelContainer::Benchmark( const AITData* inputs, size_t num_inputs, AITData* outputs, size_t num_outputs, StreamType stream, bool graph_mode, size_t count, size_t num_threads, bool use_unique_stream_per_thread, int64_t** output_shapes_out) { if (num_threads == 0) { num_threads = std::thread::hardware_concurrency(); } std::shared_lock constants_lk(constants_sync_mutex_); if (!constant_folded_once_) { constants_lk.unlock(); std::unique_lock constants_unique_lk(constants_sync_mutex_); // Check again, another thread may have updated after we unlocked. if (!constant_folded_once_) { FoldConstantsImpl(stream); } constants_unique_lk.unlock(); constants_lk.lock(); } if (num_threads == 1) { return BenchmarkImpl( inputs, num_inputs, outputs, num_outputs, stream, graph_mode, count, output_shapes_out) / count; } // Clone the outputs, each thread needs its own set std::vector> per_thread_outputs_ptrs; std::vector> per_thread_outputs; std::vector per_thread_streams; per_thread_outputs_ptrs.reserve(num_threads - 1); per_thread_outputs.reserve(num_threads - 1); if (use_unique_stream_per_thread) { per_thread_streams.reserve(num_threads); for (size_t i = 0; i < num_threads; ++i) { per_thread_streams.push_back(RAII_StreamCreate(/*non_blocking=*/true)); } } for (size_t i = 1; i < num_threads; ++i) { std::vector cloned_outputs_ptrs; std::vector cloned_outputs; cloned_outputs_ptrs.reserve(num_outputs); cloned_outputs.reserve(num_outputs); for (size_t j = 0; j < num_outputs; ++j) { size_t num_bytes = MaxOutputStorageBytes(j); cloned_outputs_ptrs.emplace_back( RAII_DeviceMalloc(num_bytes, allocator_)); auto* new_pointer = cloned_outputs_ptrs.back().get(); DEVICE_CHECK( DeviceToDeviceCopy(new_pointer, outputs[j].ptr, num_bytes, stream)); cloned_outputs.emplace_back( new_pointer, outputs[j].shape, outputs[j].dtype); } per_thread_outputs_ptrs.push_back(std::move(cloned_outputs_ptrs)); per_thread_outputs.push_back(std::move(cloned_outputs)); } DEVICE_CHECK(StreamSynchronize(stream)); auto get_stream = [stream, use_unique_stream_per_thread, &per_thread_streams]( size_t thread_idx) { if (!use_unique_stream_per_thread) { return stream; } return per_thread_streams[thread_idx].get(); }; auto thread_func = [&](size_t thread_idx) { AITData* thread_outputs = thread_idx == 0 ? outputs : per_thread_outputs[thread_idx - 1].data(); StreamType thread_stream = get_stream(thread_idx); auto* thread_output_shapes_out = thread_idx == 0 ? output_shapes_out : nullptr; return BenchmarkImpl( inputs, num_inputs, thread_outputs, num_outputs, thread_stream, graph_mode, count, thread_output_shapes_out); }; std::vector> futures; futures.reserve(num_threads); for (size_t i = 0; i < num_threads; ++i) { futures.push_back(std::async(std::launch::async, thread_func, i)); } auto max_time = std::accumulate( futures.begin(), futures.end(), 0.f, [](float cur_val, auto& future) { return std::max(future.get(), cur_val); }); // Verify that all the outputs are the same for (size_t i = 0; i < num_outputs; ++i) { auto output_size = MaxOutputStorageBytes(i); auto output_host = std::make_unique(output_size); // NB: technically, we don't have to copy to host here, but using // std::memcmp is easier than writing a kernel that does comparisons // for both backends, and performance is not important here. DEVICE_CHECK( CopyToHost(output_host.get(), outputs[i].ptr, output_size, stream)); DEVICE_CHECK(StreamSynchronize(stream)); for (size_t thread_idx = 1; thread_idx < num_threads; ++thread_idx) { auto* thread_output = per_thread_outputs[thread_idx - 1][i].ptr; auto thread_output_host = std::make_unique(output_size); auto thread_stream = get_stream(thread_idx); DEVICE_CHECK(CopyToHost( thread_output_host.get(), thread_output, output_size, thread_stream)); DEVICE_CHECK(StreamSynchronize(thread_stream)); if (std::memcmp( output_host.get(), thread_output_host.get(), output_size)) { throw std::runtime_error( "Output " + std::to_string(i) + " did not match for a spawned thread!"); } } } auto total_num_iters = num_threads * count; return max_time / total_num_iters; } void ModelContainer::SetConstantImpl( const char* name, const AITData& tensor, bool double_buffer, StreamType stream) { auto unbound_it = unbound_constant_name_to_idx_.find(name); auto bound_it = bound_constant_name_to_idx_.find(name); if (unbound_it != unbound_constant_name_to_idx_.end()) { auto constant_idx = unbound_it->second + num_inputs_ + num_outputs_; ValidateParamDtype(tensor.dtype, constant_idx); CHECK_VECTOR_ACCESS(max_param_storage_bytes_, constant_idx) auto expected_num_bytes = max_param_storage_bytes_[constant_idx]; auto actual_num_bytes = tensor.shape.Numel() * AITemplateDtypeSizeBytes(tensor.dtype); if (expected_num_bytes != actual_num_bytes) { throw std::runtime_error( std::string( "SetConstant did not receive correct number of bytes for unbound constant ") + name + ": expected " + std::to_string(expected_num_bytes) + " but got " + std::to_string(actual_num_bytes) + ". Check that the provided tensor's shape is correct."); } } else if (bound_it != bound_constant_name_to_idx_.end()) { auto constant_idx = bound_it->second; ValidateBoundConstantDtype(tensor.dtype, constant_idx); CHECK_VECTOR_ACCESS(bound_constant_size_, constant_idx) auto expected_num_bytes = bound_constant_size_[constant_idx]; auto actual_num_bytes = tensor.shape.Numel() * AITemplateDtypeSizeBytes(tensor.dtype); if (expected_num_bytes != actual_num_bytes) { throw std::runtime_error( std::string( "SetConstant did not receive correct number of bytes for bound constant ") + name + ": expected " + std::to_string(expected_num_bytes) + " but got " + std::to_string(actual_num_bytes) + ". Check that the provided tensor's shape is correct."); } } else { throw std::runtime_error( std::string("Called SetConstant on ") + name + std::string(" but can't find in either bound or unbound constant set")); } auto* src = tensor.ptr; bool is_constant_folder_ = constant_folding_inputs_.find(name) != constant_folding_inputs_.end() || constant_folding_optional_inputs_.find(name) != constant_folding_optional_inputs_.end(); if (!double_buffer) { // If we don't use double_buffer, we can just SetConstant. if (!is_constant_folder_) { for (auto& model : models_) { model->SetConstant(name, src); } } else { constant_folder_->SetConstant(name, src); } } else { // If we use double buffer, we identify whether it's a bounded constant or // not. If it's unbounded, just hold the pointer. It it's bounded, we copy // it into the constant buffer. if (unbound_it != unbound_constant_name_to_idx_.end()) { if (is_constant_folder_) { constant_folder_->SetConstant(name, src); } else { model_constants_[std::string(name)] = src; } } else { // Constant to be set is bounded, preload into constant buffer. uint8_t* constants_ptr = GetInactiveConstantsBuffer(); size_t idx = bound_it->second; // TODO: check whether src is host or device memory. DEVICE_CHECK(DeviceToDeviceCopy( constants_ptr + bound_constant_offsets_[idx], src, bound_constant_size_[idx], stream)); } } buffer_state_ = BufferState::CONSTANTS_UPDATED; } void ModelContainer::SetConstant(const char* name, const AITData& tensor) { std::lock_guard lk(constants_sync_mutex_); WaitForAllModels(/*include_constant_folder=*/true); SetConstantImpl(name, tensor); } void ModelContainer::SetManyConstants( const char** names, const AITData* tensors, size_t num_tensors) { if (num_tensors == 0) { return; } if (tensors == nullptr) { throw std::runtime_error("Tensor array cannot be null"); } std::lock_guard lk(constants_sync_mutex_); WaitForAllModels(/*include_constant_folder=*/true); for (size_t i = 0; i < num_tensors; ++i) { const char* name = names[i]; if (name == nullptr) { throw std::runtime_error("Constant name cannot be null"); } const auto& tensor = tensors[i]; SetConstantImpl(names[i], tensor); } } void ModelContainer::SwapConstantFolderBuffer() { uint8_t* constants_ptr = GetInactiveConstantsBuffer(); constant_folder_->ResetConstants(constants_ptr); size_t constant_idx = 0; for (auto offset : constant_folding_outputs_offsets_) { constant_folder_->SetOutput(constants_ptr + offset, constant_idx++); } } uint8_t* ModelContainer::GetInactiveConstantsBuffer() { uint8_t* constants_ptr{nullptr}; if (use_constants_primary_buffer_) { if (constants_secondary_ == nullptr) { constants_secondary_ = RAII_DeviceMalloc(constants_size_, allocator_); } constants_ptr = static_cast(constants_secondary_.get()); } else { constants_ptr = static_cast(constants_primary_.get()); } return constants_ptr; } void ModelContainer::SetDoubleBufferConstant( const char* name, const AITData& tensor, StreamType stream) { std::lock_guard lk(constants_double_buffer_mutex_); SetConstantImpl(name, tensor, /* double_buffer */ true, stream); } void ModelContainer::SetManyDoubleBufferConstants( const char** names, const AITData* tensors, size_t num_tensors, StreamType stream) { if (num_tensors == 0) { return; } if (tensors == nullptr) { throw std::runtime_error("Tensor array cannot be null"); } std::lock_guard lk(constants_double_buffer_mutex_); for (size_t i = 0; i < num_tensors; ++i) { const char* name = names[i]; if (name == nullptr) { throw std::runtime_error("Constant name cannot be null"); } const auto& tensor = tensors[i]; SetConstantImpl(names[i], tensor, /* double_buffer */ true, stream); } } size_t ModelContainer::NumInputs() const { return num_inputs_; } const char* ModelContainer::InputName(size_t input_idx) const { CHECK(input_idx < num_inputs_); CHECK_VECTOR_ACCESS(param_names_, input_idx) return param_names_[input_idx]; } AITemplateParamShape ModelContainer::MaxInputShape(size_t input_idx) const { CHECK(input_idx < num_inputs_); CHECK_VECTOR_ACCESS(max_param_shapes_, input_idx) auto& input_shape = max_param_shapes_[input_idx]; return AITemplateParamShape{input_shape.data(), input_shape.size()}; } AITemplateDtype ModelContainer::InputDtype(size_t input_idx) const { CHECK(input_idx < num_inputs_); CHECK_VECTOR_ACCESS(param_dtypes_, input_idx) return param_dtypes_[input_idx]; } size_t ModelContainer::NumOutputs() const { return num_outputs_; } const char* ModelContainer::OutputName(size_t output_idx) const { auto idx = output_idx + num_inputs_; CHECK_VECTOR_ACCESS(param_names_, idx) return param_names_[idx]; } AITemplateParamShape ModelContainer::MaxOutputShape(size_t output_idx) const { auto idx = output_idx + num_inputs_; CHECK_VECTOR_ACCESS(max_param_shapes_, idx) auto& out_shape = max_param_shapes_[idx]; return AITemplateParamShape{out_shape.data(), out_shape.size()}; } AITemplateDtype ModelContainer::OutputDtype(size_t output_idx) const { auto idx = output_idx + num_inputs_; CHECK_VECTOR_ACCESS(param_dtypes_, idx) return param_dtypes_[idx]; } size_t ModelContainer::MaxOutputStorageBytes(size_t output_idx) const { auto idx = output_idx + num_inputs_; CHECK_VECTOR_ACCESS(max_param_storage_bytes_, idx) return max_param_storage_bytes_[idx]; } void ModelContainer::WaitForAllModels(bool include_constant_folder) { // Wait for all on-going inferences to finish. for (auto* model : pending_models_) { try { model->WaitForCompletion(); // Something has gone horribly wrong if we hit these catch cases, but // there's not much we can do about it. Just put the model back into the // pool and carry on with folding. } catch (std::exception& e) { LOG(WARNING) << "Model threw exception when waiting for inference to finish: " << e.what() << ". Ignoring and continuing constant folding."; } catch (...) { LOG(WARNING) << "Model threw unknown exception when waiting for inference to finish. Ignoring and continuing constant foldng."; } available_models_.push_back(model); } if (include_constant_folder) { try { constant_folder_->WaitForCompletion(); } catch (...) { LOG(WARNING) << "Constant folder threw exception while waiting for completion, ignoring."; } } } void ModelContainer::FoldConstantsImpl(StreamType stream, bool double_buffer) { if (constant_folded_once_) { // We do not set the buffer state if this is the initial constant folding. buffer_state_ = BufferState::CONSTANTS_FOLDED; } if (double_buffer) { SwapConstantFolderBuffer(); } else { // NB: No need to acquire models_mutex_ here. We're guaranteed that nothing // will be concurrently messing with the Model vectors while we hold // the constants_sync_mutex_ in unique mode. See model_container.h for // the full explanation. WaitForAllModels(); } // We might have already started constant folding, make sure it finishes // first. It's OK if we throw here, there's no state to restore. // We just won't finish the folding and will need to do it again. constant_folder_->WaitForCompletion(); if (double_buffer) { std::lock_guard constants_unique_lk(constants_double_buffer_mutex_); constant_folder_->Run(stream, /*graph_mode=*/false); } else { constant_folder_->Run(stream, /*graph_mode=*/false); } constant_folded_once_ = true; } void ModelContainer::FoldConstants( StreamType stream, bool sync, bool double_buffer) { if (double_buffer) { FoldConstantsImpl(stream, double_buffer); } else { std::lock_guard constant_folding_lk(constants_sync_mutex_); FoldConstantsImpl(stream); } if (sync) { DEVICE_CHECK(StreamSynchronize(stream)); } } void ModelContainer::SwapConstants() { if (buffer_state_ != BufferState::CONSTANTS_FOLDED) { LOG(WARNING) << "Called SwapConstants without calling FoldConstants()."; return; } std::unique_lock constants_unique_lk(constants_double_buffer_mutex_); uint8_t* constants_ptr = GetInactiveConstantsBuffer(); use_constants_primary_buffer_ = !use_constants_primary_buffer_; for (auto& model : models_) { model->ResetConstants(constants_ptr); } for (auto& [name, src] : model_constants_) { for (auto& model : models_) { model->SetConstant(name.c_str(), src); } } model_constants_.clear(); buffer_state_ = BufferState::CLEAN; } size_t ModelContainer::GetNumConstants(bool unbound_constants_only) const { if (unbound_constants_only) { return unbound_constant_name_to_idx_.size(); } else { return unbound_constant_name_to_idx_.size() + bound_constant_name_to_idx_.size(); } } size_t ModelContainer::GetNumConstantFoldingInputs( bool unbound_constants_only) const { if (unbound_constants_only) { return constant_folding_inputs_.size(); } else { return constant_folding_inputs_.size() + constant_folding_optional_inputs_.size(); } } void ModelContainer::WriteAllConstantNamesTo( const char** constant_names_out, bool unbound_constants_only, bool constant_folding_inputs_only) const { size_t num_to_write = constant_folding_inputs_only ? GetNumConstants(unbound_constants_only) : GetNumConstantFoldingInputs(unbound_constants_only); if (constant_names_out == nullptr && num_to_write != 0) { throw std::runtime_error("constant_names_out cannot be nullptr."); } size_t idx = 0; for (auto& [name, _] : unbound_constant_name_to_idx_) { if (!constant_folding_inputs_only || constant_folding_inputs_.find(name) != constant_folding_inputs_.end()) { constant_names_out[idx++] = name.c_str(); } } if (!unbound_constants_only) { for (auto& [name, _] : bound_constant_name_to_idx_) { if (!constant_folding_inputs_only || constant_folding_optional_inputs_.find(name) != constant_folding_optional_inputs_.end()) { constant_names_out[idx++] = name.c_str(); } } } } void ModelContainer::PrepareForRun( Model* model, const AITData* inputs, size_t num_inputs, AITData* outputs, size_t num_outputs) { if (num_inputs != num_inputs_) { auto msg = "Got wrong number of inputs; expected " + std::to_string(num_inputs_) + ", got " + std::to_string(num_inputs); throw std::runtime_error(std::move(msg)); } if (num_inputs > 0 && inputs == nullptr) { throw std::runtime_error("inputs cannot be null"); } if (num_outputs != num_outputs_) { auto msg = "Got wrong number of outputs; expected " + std::to_string(num_outputs_) + ", got " + std::to_string(num_outputs); throw std::runtime_error(std::move(msg)); } if (num_outputs > 0 && outputs == nullptr) { throw std::runtime_error("outputs cannot be null"); } for (size_t i = 0; i < num_inputs_; ++i) { auto& input = inputs[i]; ValidateParamDtype(input.dtype, i); model->SetInput(input.ptr, input.shape, i); } for (size_t i = 0; i < num_outputs_; ++i) { auto& output = outputs[i]; ValidateParamDtype(output.dtype, i + num_inputs_); model->SetOutput(output.ptr, i); } } Model* ModelContainer::GetAvailableModel() { std::unique_lock lk(models_mutex_); if (available_models_.empty()) { ReclaimFinishedModels(lk); } auto* result = available_models_.back(); available_models_.pop_back(); return result; } void ModelContainer::ReclaimFinishedModels(std::unique_lock& lk) { // Put any complete models at the end auto it = std::stable_partition( pending_models_.begin(), pending_models_.end(), [](Model* m) { return m->IsPending(); }); if (it != pending_models_.end()) { // Move all available models to the pool. available_models_.insert( available_models_.end(), it, pending_models_.end()); pending_models_.erase(it, pending_models_.end()); return; } pending_models_available_.wait( lk, [this]() { return !pending_models_.empty(); }); // There are no available workspaces! We have to wait on one. auto* model = pending_models_.front(); pending_models_.pop_front(); lk.unlock(); try { model->WaitForCompletion(); } catch (...) { lk.lock(); available_models_.push_back(model); throw; } lk.lock(); available_models_.push_back(model); } void ModelContainer::ValidateParamDtype(AITemplateDtype dtype, size_t idx) const { CHECK_VECTOR_ACCESS(param_dtypes_, idx) if (dtype != param_dtypes_[idx]) { throw std::runtime_error( "Got wrong dtype for param " + std::to_string(idx) + "; expected " + GetEnumString(param_dtypes_[idx]) + ", got " + GetEnumString(dtype)); } } void ModelContainer::ValidateBoundConstantDtype( AITemplateDtype dtype, size_t idx) const { CHECK_VECTOR_ACCESS(bound_constant_dtypes_, idx) if (dtype != bound_constant_dtypes_[idx]) { throw std::runtime_error( "Got wrong dtype for param " + std::to_string(idx) + "; expected " + GetEnumString(bound_constant_dtypes_[idx]) + ", got " + GetEnumString(dtype)); } } float ModelContainer::BenchmarkImpl( const AITData* inputs, size_t num_inputs, AITData* outputs, size_t num_outputs, StreamType stream, bool graph_mode, size_t count, int64_t** output_shapes_out) { auto* model = GetAvailableModel(); float runtime_ms = 0.; auto start_event = RAII_CreateEvent(); auto end_event = RAII_CreateEvent(); try { PrepareForRun(model, inputs, num_inputs, outputs, num_outputs); DEVICE_CHECK(EventRecord(start_event.get(), stream)); for (size_t i = 0; i < count; ++i) { model->Run(stream, graph_mode); } } catch (...) { std::lock_guard lk(models_mutex_); available_models_.push_back(model); throw; } if (output_shapes_out) { for (size_t i = 0; i < num_outputs; ++i) { auto* out_shape = output_shapes_out[i]; model->GetOutputShape(i, out_shape); } } // Push the model back into the pool before synchronizing the event // to exercise the concurrency code { std::lock_guard lk(models_mutex_); pending_models_.push_back(model); } pending_models_available_.notify_one(); DEVICE_CHECK(EventRecord(end_event.get(), stream)); DEVICE_CHECK(EventSynchronize(end_event.get())); DEVICE_CHECK( EventElapsedTime(&runtime_ms, start_event.get(), end_event.get())); LOG(INFO) << "Benchmark runtime ms/iter: " << runtime_ms / count; return runtime_ms; } } // namespace ait