File size: 6,121 Bytes
7ff4dd0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 |
# Modified by $@#Anonymous#@$ #20240123
# Copyright (c) 2023, Albert Gu, Tri Dao.
import sys
import warnings
import os
import re
import ast
from pathlib import Path
from packaging.version import parse, Version
import platform
import shutil
from setuptools import setup, find_packages
import subprocess
from wheel.bdist_wheel import bdist_wheel as _bdist_wheel
import torch
from torch.utils.cpp_extension import (
BuildExtension,
CppExtension,
CUDAExtension,
CUDA_HOME,
)
# ninja build does not work unless include_dirs are abs path
this_dir = os.path.dirname(os.path.abspath(__file__))
# For CI, we want the option to build with C++11 ABI since the nvcr images use C++11 ABI
# NOTE: only the exact string "TRUE" enables this; any other value (including
# "true"/"1") leaves the flag off.
FORCE_CXX11_ABI = os.getenv("FORCE_CXX11_ABI", "FALSE") == "TRUE"
def get_cuda_bare_metal_version(cuda_dir):
    """Query the ``nvcc`` binary under *cuda_dir* for the installed CUDA version.

    Parameters:
        cuda_dir: Root of the CUDA toolkit installation (e.g. ``CUDA_HOME``).

    Returns:
        A ``(raw_output, bare_metal_version)`` tuple: the full text of
        ``nvcc -V`` and the parsed release version (a ``packaging`` Version).

    Raises:
        subprocess.CalledProcessError: if ``nvcc`` exits non-zero.
        ValueError: if the output contains no "release" token.
    """
    # os.path.join instead of string concatenation keeps the path valid on
    # Windows, where "cuda_dir + '/bin/nvcc'" would mix path separators.
    nvcc_path = os.path.join(cuda_dir, "bin", "nvcc")
    raw_output = subprocess.check_output([nvcc_path, "-V"], universal_newlines=True)
    # nvcc -V prints e.g. "... release 11.8, V11.8.89": take the token right
    # after "release" and strip the trailing comma before parsing.
    output = raw_output.split()
    release_idx = output.index("release") + 1
    bare_metal_version = parse(output[release_idx].split(",")[0])
    return raw_output, bare_metal_version
# Kernel variants to compile; each entry must be a key of the `sources` and
# `names` tables inside get_ext(). Only "oflex" is built by default; the
# commented lines preserve the other known configurations.
MODES = ["oflex"]
# MODES = ["core", "ndstate", "oflex"]
# MODES = ["core", "ndstate", "oflex", "nrow"]
def get_ext():
    """Build one CUDAExtension per entry in the module-level ``MODES`` list.

    Probes the local CUDA toolkit via ``CUDA_HOME`` to decide which
    ``-gencode`` architecture flags and nvcc options are safe to emit, then
    instantiates an extension for each requested kernel variant from the
    ``sources``/``names`` tables below.

    Returns:
        list: ``CUDAExtension`` objects, one per entry in ``MODES``.

    Raises:
        KeyError: if a ``MODES`` entry has no source/name table entry
            (previously ``dict.get(..., None)`` deferred this to an opaque
            failure inside ``CUDAExtension``).
    """
    cc_flag = []
    print("\n\ntorch.__version__ = {}\n\n".format(torch.__version__))
    print("\n\nCUDA_HOME = {}\n\n".format(CUDA_HOME))
    # Check, if CUDA11 is installed for compute capability 8.0
    multi_threads = True
    gencode_sm90 = False
    if CUDA_HOME is not None:
        _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME)
        print("CUDA version: ", bare_metal_version, flush=True)
        # sm_90 (Hopper) code generation is only available from CUDA 11.8 on.
        if bare_metal_version >= Version("11.8"):
            gencode_sm90 = True
        if bare_metal_version < Version("11.6"):
            warnings.warn("CUDA version earlier than 11.6 may lead to performance mismatch.")
        # nvcc's "--threads" option first appeared in CUDA 11.2.
        if bare_metal_version < Version("11.2"):
            multi_threads = False
    cc_flag.extend(["-gencode", "arch=compute_70,code=sm_70"])
    cc_flag.extend(["-gencode", "arch=compute_80,code=sm_80"])
    if gencode_sm90:
        cc_flag.extend(["-gencode", "arch=compute_90,code=sm_90"])
    if multi_threads:
        # Compile the per-architecture cubins in parallel.
        cc_flag.extend(["--threads", "4"])
    # HACK: The compiler flag -D_GLIBCXX_USE_CXX11_ABI is set to be the same as
    # torch._C._GLIBCXX_USE_CXX11_ABI
    # https://github.com/pytorch/pytorch/blob/8472c24e3b5b60150096486616d98b7bea01500b/torch/utils/cpp_extension.py#L920
    if FORCE_CXX11_ABI:
        torch._C._GLIBCXX_USE_CXX11_ABI = True
    # Per-mode source file lists for the selective-scan kernel variants.
    sources = dict(
        core=[
            "csrc/selective_scan/cus/selective_scan.cpp",
            "csrc/selective_scan/cus/selective_scan_core_fwd.cu",
            "csrc/selective_scan/cus/selective_scan_core_bwd.cu",
        ],
        nrow=[
            "csrc/selective_scan/cusnrow/selective_scan_nrow.cpp",
            "csrc/selective_scan/cusnrow/selective_scan_core_fwd.cu",
            "csrc/selective_scan/cusnrow/selective_scan_core_fwd2.cu",
            "csrc/selective_scan/cusnrow/selective_scan_core_fwd3.cu",
            "csrc/selective_scan/cusnrow/selective_scan_core_fwd4.cu",
            "csrc/selective_scan/cusnrow/selective_scan_core_bwd.cu",
            "csrc/selective_scan/cusnrow/selective_scan_core_bwd2.cu",
            "csrc/selective_scan/cusnrow/selective_scan_core_bwd3.cu",
            "csrc/selective_scan/cusnrow/selective_scan_core_bwd4.cu",
        ],
        ndstate=[
            "csrc/selective_scan/cusndstate/selective_scan_ndstate.cpp",
            "csrc/selective_scan/cusndstate/selective_scan_core_fwd.cu",
            "csrc/selective_scan/cusndstate/selective_scan_core_bwd.cu",
        ],
        oflex=[
            "csrc/selective_scan/cusoflex/selective_scan_oflex.cpp",
            "csrc/selective_scan/cusoflex/selective_scan_core_fwd.cu",
            "csrc/selective_scan/cusoflex/selective_scan_core_bwd.cu",
        ],
    )
    # Per-mode compiled extension module names.
    names = dict(
        core="selective_scan_cuda_core",
        nrow="selective_scan_cuda_nrow",
        ndstate="selective_scan_cuda_ndstate",
        oflex="selective_scan_cuda_oflex",
    )
    ext_modules = [
        CUDAExtension(
            # Direct indexing (not .get(..., None)) so an unknown mode fails
            # immediately with a KeyError naming the bad entry.
            name=names[MODE],
            sources=sources[MODE],
            extra_compile_args={
                "cxx": ["-O3", "-std=c++17"],
                "nvcc": [
                    "-O3",
                    "-std=c++17",
                    "-U__CUDA_NO_HALF_OPERATORS__",
                    "-U__CUDA_NO_HALF_CONVERSIONS__",
                    "-U__CUDA_NO_BFLOAT16_OPERATORS__",
                    "-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
                    "-U__CUDA_NO_BFLOAT162_OPERATORS__",
                    "-U__CUDA_NO_BFLOAT162_CONVERSIONS__",
                    "--expt-relaxed-constexpr",
                    "--expt-extended-lambda",
                    "--use_fast_math",
                    "--ptxas-options=-v",
                    "-lineinfo",
                ]
                + cc_flag
            },
            # ninja builds require absolute include paths (see note at top).
            include_dirs=[Path(this_dir) / "csrc" / "selective_scan"],
        )
        for MODE in MODES
    ]
    return ext_modules
# Build the extension list once at import time; empty only if MODES is empty.
ext_modules = get_ext()
setup(
    name="selective_scan",
    version="0.0.2",
    # No pure-Python packages are shipped; this distribution consists solely
    # of the compiled CUDA extension modules in `ext_modules`.
    packages=[],
    author="Tri Dao, Albert Gu, $@#Anonymous#@$ ",
    author_email="tri@tridao.me, agu@cs.cmu.edu, $@#Anonymous#EMAIL@$",
    description="selective scan",
    long_description="",
    long_description_content_type="text/markdown",
    url="https://github.com/state-spaces/mamba",
    classifiers=[
        "Programming Language :: Python :: 3",
        "License :: OSI Approved :: BSD License",
        "Operating System :: Unix",
    ],
    ext_modules=ext_modules,
    # Use torch's BuildExtension to drive compilation only when there are
    # extensions to build; the plain bdist_wheel command is kept either way.
    cmdclass={"bdist_wheel": _bdist_wheel, "build_ext": BuildExtension} if ext_modules else {"bdist_wheel": _bdist_wheel,},
    python_requires=">=3.7",
    install_requires=[
        "torch",
        "packaging",
        "ninja",
        "einops",
    ],
)
|