| import torch.multiprocessing as mp | |
| import torch | |
| from .raw_softsplat import worker_interface as raw_softsplat | |
| from .costvol import worker_interface as costvol | |
| from .sepconv import worker_interface as sepconv | |
| from .utils import to_shared_memory, to_device | |
| import taichi as ti | |
| import traceback | |
| def f(child_conn, device: torch.DeviceObjType): | |
| ti.init(arch=ti.gpu) | |
| while True: | |
| op_name, tensors = child_conn.recv() | |
| tensors = to_device(tensors, device) | |
| try: | |
| if "softsplat" in op_name: | |
| result = raw_softsplat(op_name, tensors) | |
| elif "costvol" in op_name: | |
| result = costvol(op_name, tensors) | |
| elif "sepconv" in op_name: | |
| result = sepconv(op_name, tensors) | |
| else: | |
| raise NotImplementedError(op_name) | |
| child_conn.send(to_shared_memory(result)) | |
| except: | |
| child_conn.send(traceback.format_exc()) | |