from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
def get_ext_modules():
    """Return the list of native extensions to build for this package.

    The list holds a single ``CUDAExtension`` named ``_msmv_sampling_cuda``,
    compiled from one C++ source and two CUDA sources (forward and backward)
    located in the ``msmv_sampling`` directory. That directory is also passed
    as an include directory.

    Source paths are relative, so they are resolved against the directory the
    build is invoked from.

    Returns:
        list: A one-element list containing the configured extension object.
    """
    return [
        CUDAExtension(
            name='_msmv_sampling_cuda',
            sources=[
                'msmv_sampling/msmv_sampling.cpp',
                'msmv_sampling/msmv_sampling_forward.cu',
                'msmv_sampling/msmv_sampling_backward.cu',
            ],
            include_dirs=['msmv_sampling'],
        ),
    ]
# Build entry point: registers the distribution under the name 'csrc' and
# hands the extension list to setuptools.
setup(
    name='csrc',
    ext_modules=get_ext_modules(),
    # Use PyTorch's build_ext command class instead of the setuptools default
    # to build the extension modules.
    cmdclass={'build_ext': BuildExtension},
)