|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import onnx |
|
|
import tensorrt as trt |
|
|
from onnx import TensorProto, helper |
|
|
|
|
|
|
|
|
def trt_dtype_to_onnx(dtype):
    """Translate a TensorRT data type into the matching ONNX tensor data type.

    Args:
        dtype: Value compared against the ``trt`` data type constants
            (``float16``, ``bfloat16``, ``float32``, ``int32``, ``int64``,
            ``bool``).

    Returns:
        The corresponding ``TensorProto.DataType`` member.

    Raises:
        TypeError: If ``dtype`` matches none of the supported types.
    """
    # Attribute names are resolved lazily, in order, so a constant is only
    # looked up on the ``trt`` / ``TensorProto.DataType`` namespaces when the
    # comparison actually reaches it.
    name_pairs = (
        ("float16", "FLOAT16"),
        ("bfloat16", "BFLOAT16"),
        ("float32", "FLOAT"),
        ("int32", "INT32"),
        ("int64", "INT64"),
        ("bool", "BOOL"),
    )
    for trt_attr, onnx_attr in name_pairs:
        if dtype == getattr(trt, trt_attr):
            return getattr(TensorProto.DataType, onnx_attr)
    raise TypeError("%s is not supported" % dtype)
|
|
|
|
|
|
|
|
def to_onnx(network, path, name: str = None):
    """Write the structure of a network to an ONNX file.

    Every layer of ``network`` becomes one ONNX node in the ``com.nvidia``
    domain whose op type is ``str(layer.type)``; only tensor names, dtypes
    and shapes of the network inputs/outputs are recorded, and the graph is
    built with ``initializer=None``.

    Args:
        network: Object exposing ``num_inputs``/``get_input``,
            ``num_outputs``/``get_output`` and ``num_layers``/``get_layer``.
        path: Destination handed to ``onnx.save``.
        name: Graph name; ``"debug_graph"`` is used when ``None``.
    """
    graph_name = "debug_graph" if name is None else name

    def describe(tensor):
        # Value info carries the tensor's name, converted dtype and shape.
        return helper.make_tensor_value_info(
            tensor.name, trt_dtype_to_onnx(tensor.dtype), list(tensor.shape))

    graph_inputs = [
        describe(network.get_input(idx)) for idx in range(network.num_inputs)
    ]
    graph_outputs = [
        describe(network.get_output(idx)) for idx in range(network.num_outputs)
    ]

    graph_nodes = []
    for layer_idx in range(network.num_layers):
        layer = network.get_layer(layer_idx)

        # Input slots that return None are skipped rather than emitted.
        candidates = (layer.get_input(k) for k in range(layer.num_inputs))
        input_names = [tensor.name for tensor in candidates if tensor is not None]
        output_names = []
        for k in range(layer.num_outputs):
            output_names.append(layer.get_output(k).name)

        node = helper.make_node(str(layer.type),
                                name=layer.name,
                                inputs=input_names,
                                outputs=output_names,
                                domain="com.nvidia")
        graph_nodes.append(node)

    graph = helper.make_graph(graph_nodes,
                              graph_name,
                              graph_inputs,
                              graph_outputs,
                              initializer=None)
    model = helper.make_model(graph, producer_name='NVIDIA')
    onnx.save(model, path)
|
|
|