"""Build script for the ``gscuda`` CUDA extension.

Compiles ``gswrapper.cpp`` and ``gs.cu`` from ``utils/gs_cuda_dmax`` into a
single extension module named ``gscuda`` using PyTorch's C++/CUDA build
helpers (``CUDAExtension`` and ``BuildExtension``).
"""
import os

from setuptools import setup
import torch
from torch.utils.cpp_extension import BuildExtension, CUDAExtension

print("Building gscuda")

# Relative directory that holds the extension's C++ and CUDA sources.
file_path = "utils/gs_cuda_dmax"

# Source files handed to the compiler, in the same order as before.
source_files = [
    os.path.join(file_path, filename)
    for filename in ("gswrapper.cpp", "gs.cu")
]

# The "lib" directory that sits next to the installed torch package; it is
# passed as a library search directory (``library_dirs``) for the extension.
torch_lib_dir = os.path.join(os.path.dirname(torch.__file__), "lib")

gscuda_extension = CUDAExtension(
    name="gscuda",  # name under which the compiled module is imported
    sources=source_files,
    library_dirs=[torch_lib_dir],
)

setup(
    name="gscuda",  # distribution name
    ext_modules=[gscuda_extension],
    # BuildExtension replaces the default build_ext command for this build.
    cmdclass={"build_ext": BuildExtension},
)