## UNUSED BY KERNEL-BUILDER ## File is retained for reference, but is not currently used in the build process. # import os # from pathlib import Path # from datetime import datetime # import subprocess # from setuptools import setup, find_packages # from torch.utils.cpp_extension import ( # BuildExtension, # CUDAExtension, # IS_WINDOWS, # CUDA_HOME # ) # def is_flag_set(flag: str) -> bool: # return os.getenv(flag, "FALSE").lower() in ["true", "1", "y", "yes"] # def get_features_args(): # features_args = [] # if is_flag_set("FLASH_MLA_DISABLE_FP16"): # features_args.append("-DFLASH_MLA_DISABLE_FP16") # return features_args # def get_arch_flags(): # # Check NVCC Version # # NOTE The "CUDA_HOME" here is not necessarily from the `CUDA_HOME` environment variable. For more details, see `torch/utils/cpp_extension.py` # assert CUDA_HOME is not None, "PyTorch must be compiled with CUDA support" # nvcc_version = subprocess.check_output( # [os.path.join(CUDA_HOME, "bin", "nvcc"), '--version'], stderr=subprocess.STDOUT # ).decode('utf-8') # nvcc_version_number = nvcc_version.split('release ')[1].split(',')[0].strip() # major, minor = map(int, nvcc_version_number.split('.')) # print(f'Compiling using NVCC {major}.{minor}') # DISABLE_SM100 = is_flag_set("FLASH_MLA_DISABLE_SM100") # DISABLE_SM90 = is_flag_set("FLASH_MLA_DISABLE_SM90") # if major < 12 or (major == 12 and minor <= 8): # assert DISABLE_SM100, "sm100 compilation for Flash MLA requires NVCC 12.9 or higher. Please set FLASH_MLA_DISABLE_SM100=1 to disable sm100 compilation, or update your environment." 
# TODO Implement this # arch_flags = [] # if not DISABLE_SM100: # arch_flags.extend(["-gencode", "arch=compute_100f,code=sm_100f"]) # if not DISABLE_SM90: # arch_flags.extend(["-gencode", "arch=compute_90a,code=sm_90a"]) # return arch_flags # def get_nvcc_thread_args(): # nvcc_threads = os.getenv("NVCC_THREADS") or "32" # return ["--threads", nvcc_threads] # subprocess.run(["git", "submodule", "update", "--init", "csrc/cutlass"]) # this_dir = os.path.dirname(os.path.abspath(__file__)) # if IS_WINDOWS: # cxx_args = ["/O2", "/std:c++20", "/DNDEBUG", "/W0"] # else: # cxx_args = ["-O3", "-std=c++20", "-DNDEBUG", "-Wno-deprecated-declarations"] # ext_modules = [] # ext_modules.append( # CUDAExtension( # name="flash_mla.cuda", # sources=[ # # API # "csrc/api/api.cpp", # # Misc kernels for decoding # "csrc/smxx/decode/get_decoding_sched_meta/get_decoding_sched_meta.cu", # "csrc/smxx/decode/combine/combine.cu", # # sm90 dense decode # "csrc/sm90/decode/dense/instantiations/fp16.cu", # "csrc/sm90/decode/dense/instantiations/bf16.cu", # # sm90 sparse decode # "csrc/sm90/decode/sparse_fp8/instantiations/model1_persistent_h64.cu", # "csrc/sm90/decode/sparse_fp8/instantiations/model1_persistent_h128.cu", # "csrc/sm90/decode/sparse_fp8/instantiations/v32_persistent_h64.cu", # "csrc/sm90/decode/sparse_fp8/instantiations/v32_persistent_h128.cu", # # sm90 sparse prefill # "csrc/sm90/prefill/sparse/fwd.cu", # "csrc/sm90/prefill/sparse/instantiations/phase1_k512.cu", # "csrc/sm90/prefill/sparse/instantiations/phase1_k512_topklen.cu", # "csrc/sm90/prefill/sparse/instantiations/phase1_k576.cu", # "csrc/sm90/prefill/sparse/instantiations/phase1_k576_topklen.cu", # # sm100 dense prefill & backward # "csrc/sm100/prefill/dense/fmha_cutlass_fwd_sm100.cu", # "csrc/sm100/prefill/dense/fmha_cutlass_bwd_sm100.cu", # # sm100 sparse prefill # "csrc/sm100/prefill/sparse/fwd/head64/instantiations/phase1_k512.cu", # "csrc/sm100/prefill/sparse/fwd/head64/instantiations/phase1_k576.cu",
# "csrc/sm100/prefill/sparse/fwd/head128/instantiations/phase1_k512.cu", # "csrc/sm100/prefill/sparse/fwd/head128/instantiations/phase1_k576.cu", # "csrc/sm100/prefill/sparse/fwd_for_small_topk/head128/instantiations/phase1_prefill_k512.cu", # # sm100 sparse decode # "csrc/sm100/decode/head64/instantiations/v32.cu", # "csrc/sm100/decode/head64/instantiations/model1.cu", # "csrc/sm100/prefill/sparse/fwd_for_small_topk/head128/instantiations/phase1_decode_k512.cu", # ], # extra_compile_args={ # "cxx": cxx_args + get_features_args(), # "nvcc": [ # "-O3", # "-std=c++20", # "-DNDEBUG", # "-D_USE_MATH_DEFINES", # "-Wno-deprecated-declarations", # "-U__CUDA_NO_HALF_OPERATORS__", # "-U__CUDA_NO_HALF_CONVERSIONS__", # "-U__CUDA_NO_HALF2_OPERATORS__", # "-U__CUDA_NO_BFLOAT16_CONVERSIONS__", # "--expt-relaxed-constexpr", # "--expt-extended-lambda", # "--use_fast_math", # "--ptxas-options=-v,--register-usage-level=10,--warn-on-spills,--warn-on-local-memory-usage,--warn-on-double-precision-use", # "-lineinfo", # "--source-in-ptx", # ] + get_features_args() + get_arch_flags() + get_nvcc_thread_args(), # }, # include_dirs=[ # Path(this_dir) / "csrc", # Path(this_dir) / "csrc" / "kerutils" / "include", # TODO Remove me # Path(this_dir) / "csrc" / "sm90", # Path(this_dir) / "csrc" / "cutlass" / "include", # Path(this_dir) / "csrc" / "cutlass" / "tools" / "util" / "include", # ], # ) # ) # try: # cmd = ['git', 'rev-parse', '--short', 'HEAD'] # rev = '+' + subprocess.check_output(cmd).decode('ascii').rstrip() # except Exception as _: # now = datetime.now() # date_time_str = now.strftime("%Y-%m-%d-%H-%M-%S") # rev = '+' + date_time_str # setup( # name="flash_mla", # version="1.0.0" + rev, # packages=find_packages(include=['flash_mla']), # ext_modules=ext_modules, # cmdclass={"build_ext": BuildExtension}, # )