| #pragma once |
|
|
| #include <ATen/cuda/CUDAContext.h> |
| #include <cuda_runtime.h> |
| #include <dlfcn.h> |
| #include <filesystem> |
| #include <fstream> |
| #include <nvrtc.h> |
| #include <regex> |
| #include <string> |
|
|
| #include "../utils/exception.hpp" |
| #include "../utils/format.hpp" |
| #include "../utils/hash.hpp" |
| #include "../utils/lazy_init.hpp" |
| #include "../utils/system.hpp" |
| #include "cache.hpp" |
| #include "device_runtime.hpp" |
|
|
| namespace deep_gemm { |
|
|
| |
| static void* get_nvrtc_handle() { |
| static void* handle = nullptr; |
| if (handle == nullptr) { |
| handle = dlopen("libnvrtc.so", RTLD_LAZY | RTLD_LOCAL); |
| if (handle == nullptr) |
| handle = dlopen("libnvrtc.so.12", RTLD_LAZY | RTLD_LOCAL); |
| if (handle == nullptr) |
| handle = dlopen("libnvrtc.so.11.2", RTLD_LAZY | RTLD_LOCAL); |
| DG_HOST_ASSERT(handle != nullptr and "Failed to load NVRTC library (libnvrtc.so)"); |
| } |
| return handle; |
| } |
|
|
| #define DECL_LAZY_NVRTC_FUNCTION(name) \ |
| template <typename... Args> \ |
| static auto lazy_##name(Args&&... args) -> decltype(name(args...)) { \ |
| using FuncType = decltype(&name); \ |
| static FuncType func = nullptr; \ |
| if (func == nullptr) { \ |
| func = reinterpret_cast<FuncType>(dlsym(get_nvrtc_handle(), #name)); \ |
| DG_HOST_ASSERT(func != nullptr and "Failed to load NVRTC function " #name); \ |
| } \ |
| return func(std::forward<decltype(args)>(args)...); \ |
| } |
|
|
| DECL_LAZY_NVRTC_FUNCTION(nvrtcVersion); |
| DECL_LAZY_NVRTC_FUNCTION(nvrtcCreateProgram); |
| DECL_LAZY_NVRTC_FUNCTION(nvrtcCompileProgram); |
| DECL_LAZY_NVRTC_FUNCTION(nvrtcGetProgramLogSize); |
| DECL_LAZY_NVRTC_FUNCTION(nvrtcGetProgramLog); |
| DECL_LAZY_NVRTC_FUNCTION(nvrtcGetPTXSize); |
| DECL_LAZY_NVRTC_FUNCTION(nvrtcGetPTX); |
| DECL_LAZY_NVRTC_FUNCTION(nvrtcGetCUBINSize); |
| DECL_LAZY_NVRTC_FUNCTION(nvrtcGetCUBIN); |
| DECL_LAZY_NVRTC_FUNCTION(nvrtcDestroyProgram); |
| DECL_LAZY_NVRTC_FUNCTION(nvrtcGetErrorString); |
|
|
// Evaluates `cmd` exactly once and, if the result is not `NVRTC_SUCCESS`, throws a
// `DGException` tagged "NVRTC" that carries the current file, line and the text returned by
// `nvrtcGetErrorString` for that result.
#define DG_NVRTC_CHECK(cmd) \
do { \
    if (const auto dg_nvrtc_status_ = (cmd); dg_nvrtc_status_ != NVRTC_SUCCESS) \
        throw DGException("NVRTC", __FILE__, __LINE__, lazy_nvrtcGetErrorString(dg_nvrtc_status_)); \
} while (0)
|
|
| class Compiler { |
| public: |
| static std::filesystem::path library_root_path; |
| static std::filesystem::path library_include_path; |
| static std::filesystem::path cutlass_include_path; |
| static std::filesystem::path cuda_home; |
| static std::string library_version; |
| static std::filesystem::path cuobjdump_path; |
|
|
| static std::string get_library_version() { |
| std::vector<char> buffer; |
| for (const auto& f: collect_files(library_include_path / "deep_gemm")) { |
| std::ifstream in(f, std::ios::binary); |
| DG_HOST_ASSERT(in.is_open()); |
|
|
| |
| buffer.insert(buffer.end(), |
| std::istreambuf_iterator<char>(in), |
| std::istreambuf_iterator<char>()); |
| } |
| return get_hex_digest(buffer); |
| } |
|
|
| static void prepare_init(const std::string& library_root_path, |
| const std::string& cuda_home_path_by_python, |
| const std::string& cutlass_include_path_by_python = "") { |
| Compiler::library_root_path = library_root_path; |
| Compiler::library_include_path = Compiler::library_root_path / "include"; |
| Compiler::cuda_home = cuda_home_path_by_python; |
| Compiler::library_version = get_library_version(); |
| Compiler::cuobjdump_path = Compiler::cuda_home / "bin" / "cuobjdump"; |
|
|
| |
| if (!cutlass_include_path_by_python.empty()) { |
| Compiler::cutlass_include_path = cutlass_include_path_by_python; |
| } else { |
| auto env_path = get_env<std::string>("DG_CUTLASS_INCLUDE"); |
| if (!env_path.empty()) { |
| Compiler::cutlass_include_path = env_path; |
| } else { |
| auto cutlass_home = get_env<std::string>("CUTLASS_HOME"); |
| if (!cutlass_home.empty()) { |
| Compiler::cutlass_include_path = std::filesystem::path(cutlass_home) / "include"; |
| } |
| } |
| } |
| } |
|
|
| std::string signature, flags; |
| std::filesystem::path cache_dir_path; |
|
|
| Compiler() { |
| |
| DG_HOST_ASSERT(not library_root_path.empty()); |
| DG_HOST_ASSERT(not library_include_path.empty()); |
| DG_HOST_ASSERT(not cuda_home.empty()); |
| DG_HOST_ASSERT(not library_version.empty()); |
| DG_HOST_ASSERT(not cuobjdump_path.empty()); |
|
|
| |
| cache_dir_path = std::filesystem::path(get_env<std::string>("HOME")) / ".deep_gemm"; |
| if (const auto& env_cache_dir_path = get_env<std::string>("DG_JIT_CACHE_DIR"); not env_cache_dir_path.empty()) |
| cache_dir_path = env_cache_dir_path; |
|
|
| |
| signature = "unknown-compiler"; |
| flags = fmt::format("-std=c++{} --diag-suppress=39,161,174,177,186,940 " |
| "--ptxas-options=--register-usage-level=10", |
| get_env<int>("DG_JIT_CPP_STANDARD", 20)); |
| if (get_env("DG_JIT_DEBUG", 0) or get_env("DG_JIT_PTXAS_VERBOSE", 0) or get_env("DG_JIT_PTXAS_CHECK", 0)) |
| flags += " --ptxas-options=--verbose,--warn-on-local-memory-usage"; |
| if (get_env("DG_JIT_WITH_LINEINFO", 0)) |
| flags += " -Xcompiler -rdynamic -lineinfo"; |
| } |
|
|
| virtual ~Compiler() = default; |
|
|
| std::filesystem::path make_tmp_dir() const { |
| return make_dirs(cache_dir_path / "tmp"); |
| } |
|
|
| std::filesystem::path get_tmp_file_path() const { |
| return make_tmp_dir() / get_uuid(); |
| } |
|
|
| void put(const std::filesystem::path& path, const std::string& data) const { |
| const auto tmp_file_path = get_tmp_file_path(); |
|
|
| |
| std::ofstream out(tmp_file_path, std::ios::binary); |
| DG_HOST_ASSERT(out.write(data.data(), data.size())); |
| out.close(); |
|
|
| |
| std::filesystem::rename(tmp_file_path, path); |
| } |
|
|
| std::shared_ptr<KernelRuntime> build(const std::string& name, const std::string& code) const { |
| const auto kernel_signature = fmt::format("{}$${}$${}$${}$${}", name, library_version, signature, flags, code); |
| const auto dir_path = cache_dir_path / "cache" / fmt::format("kernel.{}.{}", name, get_hex_digest(kernel_signature)); |
|
|
| |
| if (const auto& runtime = kernel_runtime_cache->get(dir_path); runtime != nullptr) |
| return runtime; |
|
|
| |
| make_dirs(dir_path); |
|
|
| |
| const auto tmp_cubin_path = get_tmp_file_path(); |
| if (get_env<int>("DG_JIT_DUMP_ASM") or get_env<int>("DG_JIT_DUMP_PTX")) { |
| |
| const auto tmp_ptx_path = get_tmp_file_path(); |
| compile(code, dir_path, tmp_cubin_path, tmp_ptx_path); |
|
|
| |
| std::filesystem::rename(tmp_ptx_path, dir_path / "kernel.ptx"); |
| } else { |
| compile(code, dir_path, tmp_cubin_path); |
| } |
|
|
| |
| const auto cubin_path = dir_path / "kernel.cubin"; |
| std::filesystem::rename(tmp_cubin_path, cubin_path); |
|
|
| |
| if (get_env<int>("DG_JIT_DUMP_ASM") or get_env<int>("DG_JIT_DUMP_SASS")) { |
| |
| const auto tmp_sass_path = get_tmp_file_path(); |
| disassemble(cubin_path, tmp_sass_path); |
|
|
| |
| std::filesystem::rename(tmp_sass_path, dir_path / "kernel.sass"); |
| } |
|
|
| |
| const auto runtime = kernel_runtime_cache->get(dir_path); |
| DG_HOST_ASSERT(runtime != nullptr); |
| return runtime; |
| } |
|
|
| static void disassemble(const std::filesystem::path &cubin_path, const std::filesystem::path &sass_path) { |
| |
| const auto command = fmt::format("{} --dump-sass {} > {}", cuobjdump_path.c_str(), cubin_path.c_str(), sass_path.c_str()); |
| if (get_env("DG_JIT_DEBUG", 0) or get_env("DG_JIT_PRINT_COMPILER_COMMAND", 0)) |
| printf("Running cuobjdump command: %s\n", command.c_str()); |
| const auto [return_code, output] = call_external_command(command); |
| if (return_code != 0) { |
| printf("cuobjdump failed: %s\n", output.c_str()); |
| DG_HOST_ASSERT(false and "cuobjdump failed"); |
| } |
| } |
|
|
| virtual void compile(const std::string &code, const std::filesystem::path& dir_path, const std::filesystem::path &cubin_path, const std::optional<std::filesystem::path> &ptx_path = std::nullopt) const = 0; |
| }; |
|
|
| DG_DECLARE_STATIC_VAR_IN_CLASS(Compiler, library_root_path); |
| DG_DECLARE_STATIC_VAR_IN_CLASS(Compiler, library_include_path); |
| DG_DECLARE_STATIC_VAR_IN_CLASS(Compiler, cutlass_include_path); |
| DG_DECLARE_STATIC_VAR_IN_CLASS(Compiler, cuda_home); |
| DG_DECLARE_STATIC_VAR_IN_CLASS(Compiler, library_version); |
| DG_DECLARE_STATIC_VAR_IN_CLASS(Compiler, cuobjdump_path); |
|
|
| class NVCCCompiler final: public Compiler { |
| std::filesystem::path nvcc_path; |
|
|
| std::pair<int, int> get_nvcc_version() const { |
| DG_HOST_ASSERT(std::filesystem::exists(nvcc_path)); |
|
|
| |
| const auto& command = std::string(nvcc_path) + " --version"; |
| const auto& [return_code, output] = call_external_command(command); |
| DG_HOST_ASSERT(return_code == 0); |
|
|
| |
| int major, minor; |
| std::smatch match; |
| DG_HOST_ASSERT(std::regex_search(output, match, std::regex(R"(release (\d+\.\d+))"))); |
| std::sscanf(match[1].str().c_str(), "%d.%d", &major, &minor); |
| DG_HOST_ASSERT((major > 12 or (major == 12 and minor >= 3)) and "NVCC version should be >= 12.3"); |
| if (major == 12 and minor < 9) |
| printf("Warning: please use at least NVCC 12.9 for the best DeepGEMM performance\n"); |
| return {major, minor}; |
| } |
|
|
| public: |
| NVCCCompiler() { |
| |
| nvcc_path = cuda_home / "bin" / "nvcc"; |
| if (const auto& env_nvcc_path = get_env<std::string>("DG_JIT_NVCC_COMPILER"); not env_nvcc_path.empty()) |
| nvcc_path = env_nvcc_path; |
| const auto& [nvcc_major, nvcc_minor] = get_nvcc_version(); |
| signature = fmt::format("NVCC{}.{}", nvcc_major, nvcc_minor); |
|
|
| |
| |
| const auto& arch = device_runtime->get_arch(false, nvcc_major > 12 or nvcc_minor >= 9); |
| std::string include_flags = fmt::format("-I{}", library_include_path.c_str()); |
| if (!cutlass_include_path.empty()) |
| include_flags += fmt::format(" -I{}", cutlass_include_path.c_str()); |
| flags = fmt::format("{} {} --gpu-architecture=sm_{} " |
| "--compiler-options=-fPIC,-O3,-fconcepts,-Wno-deprecated-declarations,-Wno-abi " |
| "-O3 --expt-relaxed-constexpr --expt-extended-lambda", |
| flags, include_flags, arch); |
| } |
|
|
| void compile(const std::string &code, const std::filesystem::path& dir_path, |
| const std::filesystem::path &cubin_path, |
| const std::optional<std::filesystem::path> &ptx_path) const override { |
| |
| const auto& code_path = dir_path / "kernel.cu"; |
| put(code_path, code); |
|
|
| |
| const auto& command = fmt::format("{} {} -cubin -o {} {}", nvcc_path.c_str(), code_path.c_str(), cubin_path.c_str(), flags); |
| if (get_env("DG_JIT_DEBUG", 0) or get_env("DG_JIT_PRINT_COMPILER_COMMAND", 0)) |
| printf("Running NVCC command: %s\n", command.c_str()); |
| const auto& [return_code, output] = call_external_command(command); |
| if (return_code != 0) { |
| printf("NVCC compilation failed: %s\n", output.c_str()); |
| DG_HOST_ASSERT(false and "NVCC compilation failed"); |
| } |
|
|
| |
| if (ptx_path.has_value()) { |
| const auto ptx_command = fmt::format("{} {} -ptx -o {} {}", nvcc_path.c_str(), code_path.c_str(), ptx_path->c_str(), flags); |
| if (get_env("DG_JIT_DEBUG", 0) or get_env("DG_JIT_PRINT_COMPILER_COMMAND", 0)) |
| printf("Running NVCC PTX command: %s\n", ptx_command.c_str()); |
| const auto [ptx_return_code, ptx_output] = call_external_command(ptx_command); |
| if (ptx_return_code != 0) { |
| printf("NVCC PTX compilation failed: %s\n", ptx_output.c_str()); |
| DG_HOST_ASSERT(false and "NVCC PTX compilation failed"); |
| } |
| } |
|
|
| |
| if (get_env("DG_JIT_PTXAS_CHECK", 0)) |
| DG_HOST_ASSERT(not std::regex_search(output, std::regex(R"(Local memory used)"))); |
|
|
| |
| if (get_env("DG_JIT_DEBUG", 0) or get_env("DG_JIT_PTXAS_VERBOSE", 0)) |
| printf("%s", output.c_str()); |
| } |
| }; |
|
|
| class NVRTCCompiler final: public Compiler { |
| public: |
| NVRTCCompiler() { |
| |
| int major, minor; |
| DG_NVRTC_CHECK(lazy_nvrtcVersion(&major, &minor)); |
| signature = fmt::format("NVRTC{}.{}", major, minor); |
| DG_HOST_ASSERT((major > 12 or (major == 12 and minor >= 3)) and "NVRTC version should be >= 12.3"); |
|
|
| |
| std::string include_dirs; |
| include_dirs += fmt::format("-I{} ", library_include_path.string()); |
| if (!cutlass_include_path.empty()) |
| include_dirs += fmt::format("-I{} ", cutlass_include_path.string()); |
| include_dirs += fmt::format("-I{} ", (cuda_home / "include").string()); |
|
|
| |
| |
| std::string pch_flags; |
| if (major > 12 or minor >= 8) { |
| pch_flags = "--pch "; |
| if (get_env<int>("DG_JIT_DEBUG", 0)) |
| pch_flags += "--pch-verbose=true "; |
| } |
|
|
| |
| |
| const auto& arch = device_runtime->get_arch(false, major > 12 or minor >= 9); |
| flags = fmt::format("{} {}--gpu-architecture=sm_{} -default-device {} --device-int128", |
| flags, include_dirs, arch, pch_flags); |
| } |
|
|
| void compile(const std::string &code, const std::filesystem::path& dir_path, |
| const std::filesystem::path &cubin_path, |
| const std::optional<std::filesystem::path> &ptx_path) const override { |
| |
| const auto& code_path = dir_path / "kernel.cu"; |
| put(code_path, code); |
|
|
| |
| std::istringstream iss(flags); |
| std::vector<std::string> options; |
| std::string option; |
| while (iss >> option) |
| options.push_back(option); |
|
|
| |
| std::vector<const char*> option_cstrs; |
| for (const auto& opt: options) |
| option_cstrs.push_back(opt.c_str()); |
|
|
| |
| if (get_env<int>("DG_JIT_DEBUG", 0) or get_env<int>("DG_JIT_PRINT_COMPILER_COMMAND", 0)) { |
| printf("Compiling JIT runtime with NVRTC options: "); |
| for (const auto& opt: options) |
| printf("%s ", opt.c_str()); |
| printf("\n"); |
| } |
|
|
| |
| nvrtcProgram program; |
| DG_NVRTC_CHECK(lazy_nvrtcCreateProgram(&program, code.c_str(), "kernel.cu", 0, nullptr, nullptr)); |
| const auto& compile_result = lazy_nvrtcCompileProgram(program, static_cast<int>(option_cstrs.size()), option_cstrs.data()); |
|
|
| |
| size_t log_size; |
| DG_NVRTC_CHECK(lazy_nvrtcGetProgramLogSize(program, &log_size)); |
| if (get_env<int>("DG_JIT_DEBUG", 0) or compile_result != NVRTC_SUCCESS) { |
| if (compile_result != NVRTC_SUCCESS) |
| DG_HOST_ASSERT(log_size > 1); |
| if (log_size > 1) { |
| std::string compilation_log(log_size, '\0'); |
| DG_NVRTC_CHECK(lazy_nvrtcGetProgramLog(program, compilation_log.data())); |
| printf("NVRTC log: %s\n", compilation_log.c_str()); |
| } |
| } |
|
|
| if (ptx_path.has_value()) { |
| |
| size_t ptx_size; |
| DG_NVRTC_CHECK(lazy_nvrtcGetPTXSize(program, &ptx_size)); |
| std::string ptx_data(ptx_size, '\0'); |
| DG_NVRTC_CHECK(lazy_nvrtcGetPTX(program, ptx_data.data())); |
|
|
| |
| put(ptx_path.value(), ptx_data); |
| } |
|
|
| |
| size_t cubin_size; |
| DG_NVRTC_CHECK(lazy_nvrtcGetCUBINSize(program, &cubin_size)); |
| std::string cubin_data(cubin_size, '\0'); |
| DG_NVRTC_CHECK(lazy_nvrtcGetCUBIN(program, cubin_data.data())); |
|
|
| |
| put(cubin_path, cubin_data); |
|
|
| |
| DG_NVRTC_CHECK(lazy_nvrtcDestroyProgram(&program)); |
| } |
| }; |
|
|
| static auto compiler = LazyInit<Compiler>([]() -> std::shared_ptr<Compiler> { |
| if (get_env<int>("DG_JIT_USE_NVRTC", 0)) { |
| return std::make_shared<NVRTCCompiler>(); |
| } else { |
| return std::make_shared<NVCCCompiler>(); |
| } |
| }); |
|
|
| } |
|
|