|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from .base_model import BaseModel
|
|
|
import torch
|
|
|
from torch.cuda import nvtx
|
|
|
from .predictor import numpy_to_torch_dtype_dict
|
|
|
|
|
|
|
|
|
class StitchingModel(BaseModel):
|
|
|
"""
|
|
|
StitchingModel
|
|
|
"""
|
|
|
|
|
|
def __init__(self, **kwargs):
|
|
|
super(StitchingModel, self).__init__(**kwargs)
|
|
|
|
|
|
def input_process(self, *data):
|
|
|
input = data[0]
|
|
|
return input
|
|
|
|
|
|
def output_process(self, *data):
|
|
|
return data[0]
|
|
|
|
|
|
def predict_trt(self, *data):
|
|
|
nvtx.range_push("forward")
|
|
|
feed_dict = {}
|
|
|
for i, inp in enumerate(self.predictor.inputs):
|
|
|
if isinstance(data[i], torch.Tensor):
|
|
|
feed_dict[inp['name']] = data[i]
|
|
|
else:
|
|
|
feed_dict[inp['name']] = torch.from_numpy(data[i]).to(device=self.device,
|
|
|
dtype=numpy_to_torch_dtype_dict[inp['dtype']])
|
|
|
preds_dict = self.predictor.predict(feed_dict, self.cudaStream)
|
|
|
outs = []
|
|
|
for i, out in enumerate(self.predictor.outputs):
|
|
|
outs.append(preds_dict[out["name"]].cpu().numpy())
|
|
|
nvtx.range_pop()
|
|
|
return outs
|
|
|
|
|
|
def predict(self, *data):
|
|
|
data = self.input_process(*data)
|
|
|
if self.predict_type == "trt":
|
|
|
preds = self.predict_trt(data)
|
|
|
else:
|
|
|
preds = self.predictor.predict(data)
|
|
|
outputs = self.output_process(*preds)
|
|
|
return outputs
|
|
|
|