// Copyright (c) Meta Platforms, Inc. and affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // #pragma once // Some helpful unique_ptr instantiations and factory functions for CUDA types #include #include #include #include "device_functions-generated.h" #include "macros.h" namespace ait { // RAII wrapper for owned GPU memory. Not that the underlying calls // to malloc/free are synchronous for simplicity. using GPUPtr = std::unique_ptr>; using StreamPtr = std:: unique_ptr::type, decltype(&StreamDestroy)>; using EventPtr = std:: unique_ptr::type, decltype(&DestroyEvent)>; using GraphPtr = std::unique_ptr< std::remove_pointer::type, std::function>; inline GPUPtr RAII_DeviceMalloc( size_t num_bytes, AITemplateAllocator& allocator) { auto* output = allocator.Allocate(num_bytes); auto deleter = [&allocator](void* ptr) mutable { allocator.Free(ptr); }; return GPUPtr(output, deleter); } inline StreamPtr RAII_StreamCreate(bool non_blocking = false) { StreamType stream; DEVICE_CHECK(StreamCreate(&stream, non_blocking)); return StreamPtr(stream, StreamDestroy); } inline EventPtr RAII_CreateEvent() { EventType event; DEVICE_CHECK(CreateEvent(&event)); return EventPtr(event, DestroyEvent); } inline GraphPtr RAII_EndCaptureAndCreateGraph( const std::function& end_capture_fn) { GraphType graph; // If this throws, we shouldn't leak memory. cudaGraphEndCapture is guaranteed // to return the NULL graph if ending the stream capture doesn't work. // We pass a custom function here instead of calling StreamEndCapture // directly so classes can manipulate state if the stream capture fails // (e.g. disabling graph mode might be useful in that case). DEVICE_CHECK(end_capture_fn(&graph)) return GraphPtr(graph, GraphDestroy); } class RAII_ProfilerRange { public: RAII_ProfilerRange(char* name) { ProfilerRangePush(name); } ~RAII_ProfilerRange() { ProfilerRangePop(); } RAII_ProfilerRange(const RAII_ProfilerRange&) = delete; RAII_ProfilerRange(RAII_ProfilerRange&&) = delete; }; } // namespace ait