//  Copyright (c) Meta Platforms, Inc. and affiliates.
//
//  Licensed under the Apache License, Version 2.0 (the "License");
//  you may not use this file except in compliance with the License.
//  You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
//  Unless required by applicable law or agreed to in writing, software
//  distributed under the License is distributed on an "AS IS" BASIS,
//  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
//  See the License for the specific language governing permissions and
//  limitations under the License.
//
#pragma once

#include <cstddef>
#include <cstdint>
#include <stdexcept>
#include <string>
#include <unordered_map>
#include <vector>

// Device/runtime helpers used below (GetLastError, DEVICE_CHECK, GPUPtr,
// StreamType, AITemplateAllocator, ...) are declared in the backend-specific
// headers that accompany this file.

namespace ait {

inline void DeviceCheckLastError(const char* file, int line) {
  auto device_error = GetLastError();
  if (device_error != GetDeviceSuccess()) {
    std::string msg = std::string("Got error: ") + GetLastErrorString() +
        " enum: " + std::to_string(device_error) + " at " + file + ": " +
        std::to_string(line);
    LOG(ERROR) << msg;
    throw std::runtime_error(msg);
  }
}

// This serves as a base class for AIT runtime objects, e.g. the compiled
// model and the constant folder. It uses CRTP as a mechanism to call into
// a few base class methods (dynamic dispatch is not needed in ModelContainer,
// so there's no need to add a vtable). Inheriting classes should implement
// the following methods:
// - RunImpl(StreamType): The bulk of the compiled model's kernel invocations
//   go here.
// - SetUpInputsOutputs(): Check the provided input/output pointers' dtypes &
//   sizes.
// - DeviceToDeviceCopies(): Called at the end of inference; copies views of
//   inputs/constants to the provided output pointer.
//
// In practice, inheriting classes are generated via MODEL_TEMPLATE in
// python/aitemplate/backend/main_templates.py.
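//
// For illustration only, a minimal (hypothetical) inheriting class might look
// like the sketch below; real subclasses are emitted by the codegen:
//
//   class Model : public ModelBase<Model> {
//    public:
//     void RunImpl(StreamType stream) {
//       // launch the generated kernels on `stream`
//     }
//     void SetUpInputsOutputs() {
//       // validate that all params_ pointers are set and shapes are bound
//     }
//     void DeviceToDeviceCopies(StreamType stream) {
//       // copy input/constant views into the output buffers
//     }
//     void ProfileImpl(StreamType stream, size_t iters,
//                      const std::string& filename) {
//       // per-op timing; invoked by Profile()
//     }
//   };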
template <typename ModelType>
class ModelBase {
 protected:
  // Should not be constructed directly, use the base class' factory function
  // instead.
  ModelBase(
      size_t blob_size,
      size_t workspace_size,
      size_t unique_workspace_size,
      size_t num_inputs,
      size_t num_outputs,
      size_t num_unbound_constants,
      uint8_t* constants,
      AITemplateAllocator& allocator)
      : blob_(RAII_DeviceMalloc(blob_size, allocator)),
        workspace_(RAII_DeviceMalloc(workspace_size, allocator)),
        params_(num_inputs + num_outputs + num_unbound_constants),
        num_inputs_(num_inputs),
        num_outputs_(num_outputs),
        constants_(constants) {
    global_workspace_ =
        static_cast<uint8_t*>(workspace_.get()) + unique_workspace_size;
    unique_workspace_ = static_cast<uint8_t*>(workspace_.get());
    DEVICE_CHECK(GetDevice(&device_idx_));
    DEVICE_CHECK(CreateEvent(&run_finished_));
#if defined(__NVCC__) || (defined(__clang__) && defined(__CUDA__))
    DEVICE_CHECK(cudaDeviceGetAttribute(
        &max_smem_size_,
        cudaDevAttrMaxSharedMemoryPerBlockOptin,
        device_idx_));
#endif
    DEVICE_CHECK(GetDeviceProperties(&device_properties_, device_idx_));
    DEVICE_CHECK(StreamCreate(&graph_capture_stream_, /*non_blocking=*/true));
  }

 public:
  ~ModelBase() {
    if (run_finished_ != nullptr) {
      DestroyEvent(run_finished_);
    }
    if (graph_capture_stream_ != nullptr) {
      StreamDestroy(graph_capture_stream_);
    }
    if (graph_exec_ != nullptr) {
      GraphExecDestroy(graph_exec_);
    }
  }

  ModelBase(ModelBase&&) = delete;
  ModelBase& operator=(ModelBase&&) = delete;
  ModelBase(const ModelBase&) = delete;
  ModelBase& operator=(const ModelBase&) = delete;

  void Run(StreamType stream, bool graph_mode) {
    auto* model = static_cast<ModelType*>(this);
    model->SetUpInputsOutputs();
    if (target_has_graph_mode && graph_mode) {
      RunAsGraph(stream);
    } else {
      model->RunImpl(stream);
    }
    model->DeviceToDeviceCopies(stream);
    DEVICE_CHECK(EventRecord(run_finished_, stream));
  }

  void Profile(StreamType stream, size_t iters, const std::string& filename) {
    auto* model = static_cast<ModelType*>(this);
    model->SetUpInputsOutputs();
    model->ProfileImpl(stream, iters, filename);
  }

  bool IsPending() {
    auto query = QueryEvent(run_finished_);
    if (query == GetDeviceNotReady()) {
      return true;
    }
    if (query != GetDeviceSuccess()) {
      LOG(WARNING) << "Pending model run did not finish successfully. Error: "
                   << GetErrorString(query);
    }
    return false;
  }

  void WaitForCompletion() {
    DEVICE_CHECK(EventSynchronize(run_finished_));
  }

  size_t NumInputs() const {
    return num_inputs_;
  }

  size_t NumOutputs() const {
    return num_outputs_;
  }

  void SetParam(const void* src, size_t param_idx) {
    CHECK_VECTOR_ACCESS(params_, param_idx)
    // const_cast is not ideal here, but it is unfortunately necessary:
    // 1) We store outputs and inputs in the same vector, and outputs
    //    cannot be const.
    // 2) Most of the codegen is not const-correct (most ops require
    //    non-const pointers). So even if we put const pointers into
    //    params, a const_cast would be required somewhere else.
    params_[param_idx].ptr = const_cast<void*>(src);
  }

  void SetInput(
      const void* src,
      const AITemplateParamShape& shape,
      size_t idx) {
    SetInputShape(shape, idx);
    SetParam(src, idx);
  }

  void SetOutput(void* src, size_t idx) {
    SetParam(src, idx + num_inputs_);
  }
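
  // Typical (hypothetical) call sequence from the owning container, shown
  // here for illustration only:
  //
  //   model->SetInput(input_ptr, input_shape, /*idx=*/0);
  //   model->SetOutput(output_ptr, /*idx=*/0);
  //   model->Run(stream, /*graph_mode=*/true);
  //   model->WaitForCompletion();
  //   model->GetOutputShape(/*idx=*/0, output_shape_out);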

  // Write the (possibly dynamic) output shape to the given pointer.
  // Note that this should be called _after_ the shape inference in
  // Run() is finished. output_shape_out should be able to store
  // at least GetOutputMaximumShape(idx).size values.
  void GetOutputShape(size_t idx, int64_t* output_shape_out) {
    const auto param_idx = idx + num_inputs_;
    CHECK_VECTOR_ACCESS(params_, param_idx);
    const auto& shape_ptrs = params_[param_idx].shape_ptrs;
    for (size_t i = 0; i < shape_ptrs.size(); ++i) {
      output_shape_out[i] = shape_ptrs[i].GetValue();
    }
  }

  void SetConstant(const char* name, const void* src) {
    auto it = constant_name_to_ptr_.find(name);
    if (it == constant_name_to_ptr_.end()) {
      throw std::out_of_range(std::string("Could not find constant ") + name);
    }
    const void** ptr = it->second;
    *ptr = src;
  }

 private:
  void SetInputShape(const AITemplateParamShape& shape, size_t idx) {
    auto& param = params_[idx];
    if (shape.size != param.shape_ptrs.size()) {
      throw std::runtime_error(
          "[SetInputShape] Got wrong param shape for input " +
          std::to_string(idx) + "; expected " +
          std::to_string(param.shape_ptrs.size()) + ", got " +
          std::to_string(shape.size));
    }
    for (size_t i = 0; i < param.shape_ptrs.size(); ++i) {
      param.shape_ptrs[i].SetValue(shape.shape_data[i]);
    }
  }

  DeviceError EndCapture(GraphType* graph_ptr) {
    auto err = StreamEndCapture(graph_capture_stream_, graph_ptr);
    if (err != GetDeviceSuccess()) {
      // If we can't take the stream out of capture mode, something is probably
      // wrong with CUDA graph for this model (e.g. there might have been an
      // illegal capture mode operation). Disable graph mode to avoid such
      // issues in future iterations.
      target_has_graph_mode = false;
      LOG(WARNING) << "Graph capture failed to end. Disabling graph mode.";
      return err;
    }
    return GetDeviceSuccess();
  }

  void RunAsGraph(StreamType stream) {
    DEVICE_CHECK(StreamBeginCapture(graph_capture_stream_, /*global=*/false));
    try {
      static_cast<ModelType*>(this)->RunImpl(graph_capture_stream_);
    } catch (...) {
      GraphType graph = nullptr;
      // No need to DEVICE_CHECK here, we want to see the original exception.
      EndCapture(&graph);
      if (graph != nullptr && GraphDestroy(graph) != GetDeviceSuccess()) {
        LOG(WARNING)
            << "Graph destruction failed while handling exception! Memory will be leaked.";
      }
      throw;
    }

    // The following function ends the capture and creates a graph inside
    // a unique_ptr that cleans it up when it goes out of scope. Note that
    // it throws an exception if EndCapture fails.
    auto graph = RAII_EndCaptureAndCreateGraph(
        [this](GraphType* graph_ptr) { return EndCapture(graph_ptr); });

    if (graph_exec_ == nullptr) {
      DEVICE_CHECK(GraphInstantiate(&graph_exec_, graph.get()));
    } else if (GraphExecUpdate(graph_exec_, graph.get()) != GetDeviceSuccess()) {
      // Consume the last cuda error, which may affect the next
      // GraphExecLaunch call.
      GetLastError();
      DEVICE_CHECK(GraphExecDestroy(graph_exec_));
      DEVICE_CHECK(GraphInstantiate(&graph_exec_, graph.get()));
    }
    DEVICE_CHECK(GraphExecLaunch(graph_exec_, stream));
  }
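
  // Summary of the graph-mode path above: RunImpl() is replayed under stream
  // capture on graph_capture_stream_; the captured graph is instantiated on
  // the first run, and later runs attempt the cheaper GraphExecUpdate(),
  // falling back to destroy + re-instantiate if the captured topology changed.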

 protected:
  int device_idx_;
  int max_smem_size_{0};
  DevicePropertyType device_properties_;
  // This event tracks when the inference is finished
  // so that this Model may be reclaimed by its owning
  // ModelContainer.
  EventType run_finished_;
  // A blob of memory used for storing intermediate tensors.
  GPUPtr blob_;
  // Memory for constants that were folded into the *.so. Unowned by Model,
  // owned by ModelContainer.
  // TODO: make this const. It can't be const right now because we derive
  // tensor pointers from it, and no tensor pointers are const.
  uint8_t* constants_;
  size_t num_inputs_;
  size_t num_outputs_;

  // The workspace blob is used as scratch memory. See
  // _generate_workspace in memory planning for more information.
  GPUPtr workspace_;
  uint8_t* global_workspace_{nullptr};
  uint8_t* unique_workspace_{nullptr};

  class ParamDim {
   public:
    ParamDim(int64_t lower_bound, int64_t upper_bound, int64_t* value)
        : lower_bound_(lower_bound),
          upper_bound_(upper_bound),
          value_(value) {}

    void SetValue(int64_t new_value) {
      if (new_value < lower_bound_ || new_value > upper_bound_) {
        throw std::out_of_range(
            "[SetValue] Dimension got value out of bounds; expected value to be in [" +
            std::to_string(lower_bound_) + ", " +
            std::to_string(upper_bound_) + "], but got " +
            std::to_string(new_value));
      }
      *value_ = new_value;
    }

    int64_t GetValue() const {
      return *value_;
    }

   private:
    int64_t lower_bound_;
    int64_t upper_bound_;
    int64_t* value_;
  };

  struct ParamInfo {
    void* ptr = nullptr;
    // TODO add offset
    const char* name;
    std::vector<ParamDim> shape_ptrs;
  };

  // Contains info for all tensors marked as inputs
  // or outputs. The first num_inputs elements are the inputs.
  // Constants are not included.
  std::vector<ParamInfo> params_;

  GraphExecType graph_exec_ = nullptr;
  StreamType graph_capture_stream_;

  std::unordered_map<std::string, const void**> constant_name_to_ptr_;
};

} // namespace ait
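
// Illustration (hypothetical, not part of the runtime): the generated
// subclass owns the int64_t dimension variables that each ParamDim points
// at, so SetInputShape() writes through to the values the kernels read:
//
//   int64_t batch = 1;  // dimension storage lives in the generated model
//   params_[0].shape_ptrs = {ParamDim(/*lower_bound=*/1,
//                                     /*upper_bound=*/64,
//                                     &batch)};
//   // A later SetInput(src, shape, /*idx=*/0) updates `batch` in place,
//   // after bounds-checking the new value against [1, 64].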