include(FetchContent)

# Fetch the sgl-project fork of FlashMLA pinned to an exact commit.
# GIT_SHALLOW stays OFF: a shallow (--depth 1) clone only works when GIT_TAG
# names a branch or tag, not an arbitrary commit hash like the one used here.
FetchContent_Declare(
  repo-flashmla
  GIT_REPOSITORY https://github.com/sgl-project/FlashMLA
  GIT_TAG be055fb7df0090fde45f08e9cb5b8b4c0272da73
  GIT_SHALLOW OFF
)

# Populate only once: calling FetchContent_Populate() for content that has
# already been populated is a fatal error (e.g. if this file is included
# twice). FetchContent_GetProperties() also defines repo-flashmla_SOURCE_DIR,
# which everything below relies on.
FetchContent_GetProperties(repo-flashmla)
if(NOT repo-flashmla_POPULATED)
  FetchContent_Populate(repo-flashmla)
endif()
|
|
# Base nvcc options for every CUDA translation unit of flashmla_ops.
# Architecture-specific -gencode flags are appended to this list further
# below, depending on the CUDA toolkit version.
set(FLASHMLA_CUDA_FLAGS
  # Let host constexpr functions be called from device code and vice versa.
  "--expt-relaxed-constexpr"
  # Allow __host__/__device__ annotations on lambdas.
  "--expt-extended-lambda"
  "--use_fast_math"
  # Suppress EDG front-end diagnostic #177 (presumably the "declared but
  # never referenced" warning) — NOTE(review): confirm against nvcc output.
  "-Xcudafe=--diag_suppress=177"
)
|
|
| |
| |
| |
# Hopper / Blackwell targets, gated on the CUDA toolkit version.
# CUDA_VERSION comes from the including project (NOTE(review): presumably a
# "major.minor" string such as "12.8" — confirm).
# The thresholds are inclusive, so the first toolkit release at each
# threshold actually enables the target. An exclusive VERSION_GREATER would
# skip sm_100a on exactly CUDA 12.8, the first toolkit that supports it. This
# also matches the VERSION_GREATER_EQUAL "13.0" check used for sm_103a.
# The expansion is quoted so an unset CUDA_VERSION gives a false condition
# instead of a malformed if().
if("${CUDA_VERSION}" VERSION_GREATER_EQUAL "12.4")
  list(APPEND FLASHMLA_CUDA_FLAGS "-gencode=arch=compute_90a,code=sm_90a")
endif()
if("${CUDA_VERSION}" VERSION_GREATER_EQUAL "12.8")
  list(APPEND FLASHMLA_CUDA_FLAGS "-gencode=arch=compute_100a,code=sm_100a")
endif()
# CUDA 13.0+ can target sm_103a. Two headers in the fetched FlashMLA tree
# (one of them in the CUTLASS copy under csrc/cutlass) are patched in place
# for SM103, and then sm_103a is added to the gencode list.
#
# Each patch writes its header only when the text actually changes.
# file(WRITE) always updates the file timestamp, so an unconditional rewrite
# on every configure would force all translation units that include the
# header to rebuild. The expected anchor text is also checked explicitly, so
# an upstream change that makes a patch inapplicable is reported instead of
# being announced as "Patched".
if("${CUDA_VERSION}" VERSION_GREATER_EQUAL "13.0")
  # --- flashmla_utils.h -----------------------------------------------------
  # Widen the IS_SM100 guard from __CUDA_ARCH__ == 1000 to the range
  # [1000, 1100), so it is also true when compiling for sm_103a
  # (__CUDA_ARCH__ == 1030).
  set(FLASHMLA_UTILS_FILE "${repo-flashmla_SOURCE_DIR}/csrc/flashmla_utils.h")
  set(flashmla_sm100_guard_old
    "#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ == 1000)\n#define IS_SM100 1")
  set(flashmla_sm100_guard_new
    "#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) && (__CUDA_ARCH__ < 1100)\n#define IS_SM100 1")
  file(READ "${FLASHMLA_UTILS_FILE}" FLASHMLA_UTILS_CONTENT)
  string(FIND "${FLASHMLA_UTILS_CONTENT}" "${flashmla_sm100_guard_old}" flashmla_guard_old_pos)
  string(FIND "${FLASHMLA_UTILS_CONTENT}" "${flashmla_sm100_guard_new}" flashmla_guard_new_pos)
  if(NOT flashmla_guard_old_pos EQUAL -1)
    string(REPLACE
      "${flashmla_sm100_guard_old}"
      "${flashmla_sm100_guard_new}"
      FLASHMLA_UTILS_CONTENT "${FLASHMLA_UTILS_CONTENT}")
    file(WRITE "${FLASHMLA_UTILS_FILE}" "${FLASHMLA_UTILS_CONTENT}")
    message(STATUS "Patched flashmla_utils.h for SM103a support")
  elseif(NOT flashmla_guard_new_pos EQUAL -1)
    # The new guard text is already present: nothing to rewrite.
    message(STATUS "flashmla_utils.h already patched for SM103a")
  else()
    message(WARNING "flashmla_utils.h: IS_SM100 guard not found; SM103a patch not applied")
  endif()

  # --- cutlass/arch/config.h -------------------------------------------------
  # Insert an "SM103 and SM103a" section directly before the existing
  # "SM101 and SM101a" section. For nvcc >= 13 the section defines
  # CUTLASS_ARCH_MMA_SM103_SUPPORTED. When compiling for __CUDA_ARCH__ == 1030
  # it also defines CUTLASS_ARCH_MMA_SM103_ENABLED and, unless they are
  # already defined, CUTLASS_ARCH_MMA_SM100A_ENABLED and
  # CUTLASS_ARCH_MMA_SM100F_ENABLED.
  # Any existing occurrence of "SM103" is treated as "already patched"
  # (either by this code or upstream).
  set(CUTLASS_CONFIG_FILE "${repo-flashmla_SOURCE_DIR}/csrc/cutlass/include/cutlass/arch/config.h")
  set(cutlass_sm101_anchor "// SM101 and SM101a")
  string(CONCAT cutlass_sm103_section
    "// SM103 and SM103a\n"
    "#if !CUTLASS_CLANG_CUDA && (__CUDACC_VER_MAJOR__ >= 13)\n"
    "#define CUTLASS_ARCH_MMA_SM103_SUPPORTED 1\n"
    "#if (!defined(CUTLASS_ARCH_MMA_SM103_ENABLED) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 1030)\n"
    "#define CUTLASS_ARCH_MMA_SM103_ENABLED 1\n"
    "#if !defined(CUTLASS_ARCH_MMA_SM100A_ENABLED)\n"
    "#define CUTLASS_ARCH_MMA_SM100A_ENABLED 1\n"
    "#endif\n"
    "#if !defined(CUTLASS_ARCH_MMA_SM100F_ENABLED)\n"
    "#define CUTLASS_ARCH_MMA_SM100F_ENABLED 1\n"
    "#endif\n"
    "#endif\n"
    "#endif\n"
    "\n"
    "/////////////////////////////////////////////////////////////////////////////////////////////////\n"
    "\n"
    "${cutlass_sm101_anchor}")
  file(READ "${CUTLASS_CONFIG_FILE}" CUTLASS_CONFIG_CONTENT)
  string(FIND "${CUTLASS_CONFIG_CONTENT}" "SM103" SM103_FOUND)
  string(FIND "${CUTLASS_CONFIG_CONTENT}" "${cutlass_sm101_anchor}" cutlass_sm101_anchor_pos)
  if(NOT SM103_FOUND EQUAL -1)
    message(STATUS "cutlass/arch/config.h already patched for SM103a")
  elseif(cutlass_sm101_anchor_pos EQUAL -1)
    message(WARNING "cutlass/arch/config.h: '${cutlass_sm101_anchor}' not found; SM103a patch not applied")
  else()
    string(REPLACE
      "${cutlass_sm101_anchor}"
      "${cutlass_sm103_section}"
      CUTLASS_CONFIG_CONTENT "${CUTLASS_CONFIG_CONTENT}")
    file(WRITE "${CUTLASS_CONFIG_FILE}" "${CUTLASS_CONFIG_CONTENT}")
    message(STATUS "Patched cutlass/arch/config.h for SM103a support")
  endif()

  list(APPEND FLASHMLA_CUDA_FLAGS "-gencode=arch=compute_103a,code=sm_103a")
endif()
|
|
|
|
# Sources compiled into flashmla_ops: the local binding translation unit,
# followed by the upstream FlashMLA sources. These are the Python API, the
# shared metadata/combine kernels, the SM90 and SM100 decode/prefill kernels,
# and the SM90 dense-FP8 extension.
set(flashmla_upstream_sources
  csrc/python_api.cpp
  csrc/smxx/get_mla_metadata.cu
  csrc/smxx/mla_combine.cu
  csrc/sm90/decode/dense/splitkv_mla.cu
  csrc/sm90/decode/sparse_fp8/splitkv_mla.cu
  csrc/sm90/prefill/sparse/fwd.cu
  csrc/sm100/decode/sparse_fp8/splitkv_mla.cu
  csrc/sm100/prefill/dense/fmha_cutlass_fwd_sm100.cu
  csrc/sm100/prefill/dense/fmha_cutlass_bwd_sm100.cu
  csrc/sm100/prefill/sparse/fwd.cu
  csrc/extension/sm90/dense_fp8/dense_fp8_python_api.cpp
  csrc/extension/sm90/dense_fp8/flash_fwd_mla_fp8_sm90.cu
  csrc/extension/sm90/dense_fp8/flash_fwd_mla_metadata.cu
)
# Upstream paths are relative to the fetched FlashMLA checkout.
list(TRANSFORM flashmla_upstream_sources PREPEND "${repo-flashmla_SOURCE_DIR}/")
set(FlashMLA_SOURCES "csrc/flashmla_extension.cc" ${flashmla_upstream_sources})
unset(flashmla_upstream_sources)
|
|
# The FlashMLA Python extension module, built from FlashMLA_SOURCES with
# USE_SABI set to SKBUILD_SABI_VERSION and WITH_SOABI.
Python_add_library(flashmla_ops MODULE USE_SABI ${SKBUILD_SABI_VERSION} WITH_SOABI ${FlashMLA_SOURCES})

# FLASHMLA_CUDA_FLAGS apply to CUDA sources only. The generator expression is
# quoted so the ;-separated flag list stays inside it as a single argument.
target_compile_options(flashmla_ops PRIVATE
  "$<$<COMPILE_LANGUAGE:CUDA>:${FLASHMLA_CUDA_FLAGS}>"
)

# Header search paths inside the fetched FlashMLA tree, including its
# bundled CUTLASS (core headers and tools/util).
target_include_directories(flashmla_ops PRIVATE
  ${repo-flashmla_SOURCE_DIR}/csrc
  ${repo-flashmla_SOURCE_DIR}/csrc/sm90
  ${repo-flashmla_SOURCE_DIR}/csrc/extension/sm90/dense_fp8/
  ${repo-flashmla_SOURCE_DIR}/csrc/cutlass/include
  ${repo-flashmla_SOURCE_DIR}/csrc/cutlass/tools/util/include
)

target_link_libraries(flashmla_ops PRIVATE ${TORCH_LIBRARIES} c10 cuda)

install(TARGETS flashmla_ops LIBRARY DESTINATION "sgl_kernel")
|
|