"""Select the custom-ops backend and re-export its public operators.

The backend name is read from ``config.yaml`` located three directory levels
above this file (key ``ops_backend``). When that file does not exist the
backend defaults to ``"taichi"``. The only accepted values are ``"taichi"``
and ``"cupy"``; the matching sibling module (``.taichi_ops`` or ``.cupy_ops``)
then provides the re-exported names.
"""
import torch.multiprocessing as mp

# The whole selection runs only in the process named "MainProcess"; in any
# other process this module defines nothing beyond ``mp``.
# NOTE(review): nesting was reconstructed from a source whose indentation was
# lost -- everything below is assumed to sit under this guard; confirm.
if mp.current_process().name == "MainProcess":
    import yaml
    import os
    from pathlib import Path

    config_path = Path(Path(__file__).parent.parent.parent.resolve(), "config.yaml")

    if os.path.exists(config_path):
        # Read through a context manager so the file handle is closed
        # deterministically instead of being left to the garbage collector.
        # NOTE(review): FullLoader is kept as in the original; consider
        # yaml.safe_load if this file can ever come from an untrusted source.
        with open(config_path, "r") as config_file:
            config = yaml.load(config_file, Loader=yaml.FullLoader)
        ops_backend = config["ops_backend"]
    else:
        ops_backend = "taichi"

    assert ops_backend in ["taichi", "cupy"], (
        f"Unsupported ops_backend {ops_backend!r}; expected 'taichi' or 'cupy'"
    )

    # Both backend modules are imported for the same set of names, so the
    # rest of the package does not depend on which one was selected.
    if ops_backend == "taichi":
        from .taichi_ops import (
            softsplat,
            ModuleSoftsplat,
            FunctionSoftsplat,
            softsplat_func,
            costvol_func,
            sepconv_func,
            init,
            batch_edt,
            FunctionAdaCoF,
            ModuleCorrelation,
            FunctionCorrelation,
            _FunctionCorrelation,
        )
    else:
        from .cupy_ops import (
            softsplat,
            ModuleSoftsplat,
            FunctionSoftsplat,
            softsplat_func,
            costvol_func,
            sepconv_func,
            init,
            batch_edt,
            FunctionAdaCoF,
            ModuleCorrelation,
            FunctionCorrelation,
            _FunctionCorrelation,
        )