# PEAR / models/modules/base/onnx_model.py
# Uploaded by BestWJH ("Upload 205 files"), commit 2c68f56 (verified).
import os
import numpy as np
import onnxruntime as ort
class OnnxModel:
    """Thin wrapper around an ONNX Runtime session with named inputs/outputs.

    Positional arguments passed to the model are bound, in order, to
    ``input_keys``. The session outputs are returned as a dict keyed, in
    order, by ``output_keys``.
    """

    def __init__(self, ckpt_fp: str, input_keys=None, output_keys=None, use_gpu=True) -> None:
        """Load an ONNX model and set up the input/output key mapping.

        Args:
            ckpt_fp: Path to the ONNX model file.
            input_keys: Graph input names, in the order the positional
                arguments of ``__call__`` will be supplied. When ``None`` or
                empty, the model's own input names (in graph order) are used.
            output_keys: Keys for the dict returned by ``__call__``, matched
                positionally against the session outputs. When ``None`` or
                empty, the model's own output names are used.
            use_gpu: Prefer the CUDA execution provider. The CPU provider is
                still listed after it as a lower-priority fallback.
        """
        if use_gpu:
            # Providers are listed in priority order; keeping CPU after CUDA
            # gives ONNX Runtime something to fall back to when CUDA is not
            # usable on this machine/build.
            providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
        else:
            providers = ['CPUExecutionProvider']
        self.ort_session = ort.InferenceSession(ckpt_fp, providers=providers)

        # Copy into fresh lists so the instance never aliases a caller's list
        # (or a shared default); fall back to the names stored in the model.
        if input_keys:
            self.input_keys = list(input_keys)
        else:
            self.input_keys = [node.name for node in self.ort_session.get_inputs()]
        if output_keys:
            self.output_keys = list(output_keys)
        else:
            self.output_keys = [node.name for node in self.ort_session.get_outputs()]

    def __call__(self, *args):
        """Run inference on the given inputs.

        Args:
            *args: Input values, bound positionally to ``self.input_keys``.
                Passing fewer values than keys is allowed; validating that
                all required inputs are present is left to ONNX Runtime.

        Returns:
            dict mapping each entry of ``self.output_keys`` to the session
            output at the same position. Outputs beyond the number of keys
            are dropped.

        Raises:
            ValueError: If more positional inputs are given than there are
                input keys, since the extra values would otherwise be
                silently ignored.
        """
        if len(args) > len(self.input_keys):
            raise ValueError(
                f"got {len(args)} inputs but only {len(self.input_keys)} "
                f"input keys are configured: {self.input_keys}"
            )
        feed = dict(zip(self.input_keys, args))
        # ``None`` asks the session for every model output, in model order.
        outputs = self.ort_session.run(None, feed)
        return dict(zip(self.output_keys, outputs))

    def run(self, *args):
        """Alias for ``__call__``."""
        return self(*args)