File size: 568 Bytes
d19bd3e | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 | from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
def get_ext_modules():
    """Describe the native extension modules built by this package.

    Returns:
        A one-element list holding the ``_msmv_sampling_cuda``
        ``CUDAExtension``. It is compiled from ``msmv_sampling.cpp`` plus the
        forward and backward ``.cu`` sources in ``msmv_sampling/``. That
        directory is also added to the include path, so headers placed
        there resolve when compiling those sources.
    """
    # Paths are relative, so this assumes setup.py is run from the directory
    # that contains ``msmv_sampling/`` (the usual ``pip install .`` layout).
    cuda_sources = [
        'msmv_sampling/msmv_sampling.cpp',
        'msmv_sampling/msmv_sampling_forward.cu',
        'msmv_sampling/msmv_sampling_backward.cu',
    ]
    sampling_ext = CUDAExtension(
        name='_msmv_sampling_cuda',
        sources=cuda_sources,
        include_dirs=['msmv_sampling'],
    )
    return [sampling_ext]
# Register the extension list with setuptools. The ``build_ext`` command is
# replaced by PyTorch's BuildExtension so the .cu sources are compiled by
# PyTorch's C++/CUDA build machinery rather than plain setuptools.
setup(
    name='csrc',
    ext_modules=get_ext_modules(),
    cmdclass=dict(build_ext=BuildExtension),
)
|