File size: 951 Bytes
61029c7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
import torch.multiprocessing as mp
import torch
from .raw_softsplat import worker_interface as raw_softsplat
from .costvol import worker_interface as costvol
from .sepconv import worker_interface as sepconv
from .utils import to_shared_memory, to_device
import taichi as ti
import traceback

def f(child_conn, device: torch.DeviceObjType):
    """Serve kernel requests arriving on ``child_conn`` until the other end hangs up.

    Initialises Taichi with ``arch=ti.gpu`` once, then loops:

    1. Receive a ``(op_name, tensors)`` pair.
    2. Move ``tensors`` to ``device`` via ``to_device``.
    3. Dispatch to the softsplat, costvol or sepconv worker interface, whichever
       key is found in ``op_name`` (checked in that order).
    4. Send back ``to_shared_memory(result)``.

    Any exception raised while handling a request is reported back as the
    ``str`` produced by ``traceback.format_exc()``, and the loop keeps serving.
    This includes a malformed message, a failed device transfer and an unknown
    op, which raises ``NotImplementedError``. The caller can therefore tell
    failure from success by the reply type; presumably the results are never
    plain strings, but confirm this against the caller.

    NOTE(review): this looks intended as a ``torch.multiprocessing`` process
    target (the module imports it). Confirm with the code that spawns it.

    Args:
        child_conn: The worker's end of a pipe; anything with ``recv()`` and
            ``send()``. Assumes ``op_name`` is a ``str`` (membership is tested
            with ``in``). TODO confirm.
        device: Target device passed to ``to_device`` for the incoming tensors.

    Returns:
        None, once ``child_conn.recv()`` raises ``EOFError``, i.e. the other
        end of the pipe has been closed.
    """
    ti.init(arch=ti.gpu)
    while True:
        try:
            message = child_conn.recv()
        except EOFError:
            # Peer closed the pipe: nothing left to serve, so exit cleanly
            # instead of dying with an uncaught traceback.
            return
        try:
            # Unpacking and the device transfer live inside the try. If they
            # fail, the caller gets a traceback reply instead of the worker
            # dying and leaving the caller blocked on a reply that never comes.
            op_name, tensors = message
            tensors = to_device(tensors, device)
            if "softsplat" in op_name:
                result = raw_softsplat(op_name, tensors)
            elif "costvol" in op_name:
                result = costvol(op_name, tensors)
            elif "sepconv" in op_name:
                result = sepconv(op_name, tensors)
            else:
                raise NotImplementedError(op_name)
            child_conn.send(to_shared_memory(result))
        except Exception:
            # Use Exception, not a bare except, so KeyboardInterrupt and
            # SystemExit still terminate the worker.
            child_conn.send(traceback.format_exc())