| # Copyright 2022-present NAVER Corp. | |
| # CC BY-NC-SA 4.0 | |
| # Available only for non-commercial use | |
| from setuptools import setup | |
| from torch import cuda | |
| from torch.utils.cpp_extension import BuildExtension, CUDAExtension | |
| # if you want to compile for all possible CUDA architectures | |
| all_cuda_archs = [] #cuda.get_gencode_flags().replace('compute=','arch=').split() | |
| setup( | |
| name='cuda_deepm', | |
| ext_modules = [ | |
| CUDAExtension( | |
| name = 'cuda_deepm', | |
| sources = ["func.cpp", "kernels.cu"], | |
| extra_compile_args = dict(nvcc=['-O2']+all_cuda_archs, cxx=['-O2']) | |
| ) | |
| ], | |
| cmdclass = { | |
| 'build_ext': BuildExtension | |
| }) | |