| | #pragma once |
| |
|
| | #include "../utils/exception.hpp" |
| | #include "../utils/format.hpp" |
| | #include "../utils/system.hpp" |
| | #include "device_runtime.hpp" |
| | #include "handle.hpp" |
| |
|
| | namespace deep_gemm { |
| |
|
| | struct LaunchArgs { |
| | std::pair<int, int> grid_dim; |
| | int num_threads; |
| | int smem_size; |
| | int cluster_dim; |
| |
|
| | LaunchArgs(const int& grid_dim_x, const int& num_threads, const int& smem_size = 0, const int& cluster_dim = 1): |
| | grid_dim({grid_dim_x, 1}), num_threads(num_threads), smem_size(smem_size), cluster_dim(cluster_dim) {} |
| |
|
| | LaunchArgs(const std::pair<int, int>& grid_dim, const int& num_threads, const int& smem_size = 0, const int& cluster_dim = 1): |
| | grid_dim(grid_dim), num_threads(num_threads), smem_size(smem_size), cluster_dim(cluster_dim) {} |
| | }; |
| |
|
| | class KernelRuntime final { |
| | public: |
| | static std::filesystem::path cuda_home; |
| |
|
| | LibraryHandle library; |
| | KernelHandle kernel; |
| |
|
| | explicit KernelRuntime(const std::filesystem::path& dir_path) { |
| | |
| | DG_HOST_ASSERT(not cuda_home.empty()); |
| |
|
| | |
| | const auto& cuobjdump_path = cuda_home / "bin" / "cuobjdump"; |
| | const auto& cubin_path = dir_path / "kernel.cubin"; |
| | if (get_env<int>("DG_JIT_DEBUG")) |
| | printf("Loading CUBIN: %s\n", cubin_path.c_str()); |
| |
|
| | |
| | |
| | const std::vector<std::string> illegal_names = {"vprintf", "__instantiate_kernel", "__internal", "__assertfail"}; |
| | const auto& [exit_code, symbols] = call_external_command(fmt::format("{} -symbols {}", cuobjdump_path.c_str(), cubin_path.c_str())); |
| | DG_HOST_ASSERT(exit_code == 0); |
| | std::istringstream iss(symbols); |
| | std::vector<std::string> symbol_names; |
| | for (std::string line; std::getline(iss, line); ) { |
| | if (line.find("STT_FUNC") == 0 and line.find("STO_ENTRY") != std::string::npos and |
| | std::none_of(illegal_names.begin(), illegal_names.end(), |
| | [&](const auto& name) { return line.find(name) != std::string::npos; })) { |
| | const auto& last_space = line.rfind(' '); |
| | symbol_names.push_back(line.substr(last_space + 1)); |
| | } |
| | } |
| | if (get_env<int>("DG_JIT_DEBUG")) { |
| | printf("Symbol names: "); |
| | for (const auto& symbol: symbol_names) |
| | printf("%s, ", symbol.c_str()); |
| | printf("\n"); |
| | } |
| |
|
| | |
| | DG_HOST_ASSERT(symbol_names.size() == 1); |
| | kernel = load_kernel(cubin_path, symbol_names[0], &library); |
| | } |
| |
|
| | static void prepare_init(const std::string& cuda_home_path_by_python) { |
| | cuda_home = cuda_home_path_by_python; |
| | } |
| |
|
| | static bool check_validity(const std::filesystem::path& dir_path) { |
| | return std::filesystem::exists(dir_path / "kernel.cu") and |
| | std::filesystem::exists(dir_path / "kernel.cubin"); |
| | } |
| |
|
| | ~KernelRuntime() noexcept(false) { |
| | unload_library(library); |
| | } |
| | }; |
| |
|
| | DG_DECLARE_STATIC_VAR_IN_CLASS(KernelRuntime, cuda_home); |
| |
|
| | template <typename Derived> |
| | class LaunchRuntime { |
| | public: |
| | template <typename Args> |
| | static std::string generate(const Args& args) { |
| | const auto& code = Derived::generate_impl(args); |
| | if (get_env<int>("DG_JIT_DEBUG", 0)) |
| | printf("Generated kernel code: %s\n", code.c_str()); |
| | return code; |
| | } |
| |
|
| | template <typename Args> |
| | static void launch(const std::shared_ptr<KernelRuntime>& kernel_runtime, const Args& args) { |
| | const auto& kernel = kernel_runtime->kernel; |
| | const auto& stream = at::cuda::getCurrentCUDAStream(); |
| | const LaunchArgs& launch_args = args.launch_args; |
| |
|
| | const dim3& grid_dim = {static_cast<unsigned>(launch_args.grid_dim.first), |
| | static_cast<unsigned>(launch_args.grid_dim.second), |
| | 1}; |
| | const dim3& block_dim = {static_cast<unsigned>(launch_args.num_threads), 1, 1}; |
| | auto config = construct_launch_config(kernel, stream, launch_args.smem_size, |
| | grid_dim, block_dim, launch_args.cluster_dim); |
| |
|
| | |
| | if (get_env<int>("DG_JIT_DEBUG")) { |
| | printf("Launch kernel with {%d, %d} x %d, shared memory: %d bytes, cluster: %d, stream: %ld\n", |
| | launch_args.grid_dim.first, launch_args.grid_dim.second, launch_args.num_threads, |
| | launch_args.smem_size, launch_args.cluster_dim, stream.id()); |
| | } |
| | Derived::launch_impl(kernel, config, args); |
| | } |
| | }; |
| |
|
| | } |
| |
|