Add onnx2engine.py
Browse files- README.md +4 -1
- onnx2engine.py +96 -0
README.md
CHANGED
|
@@ -3,4 +3,7 @@ license: apache-2.0
|
|
| 3 |
---
|
| 4 |
|
| 5 |
This project contains the onnx and tensorrt model files converted from the chatglm-6b model.
|
| 6 |
-
The inference scripts for ONNX and TensorRT will be refined later.
|
|
|
|
|
|
|
|
|
|
|
|
| 3 |
---
|
| 4 |
|
| 5 |
This project contains the onnx and tensorrt model files converted from the chatglm-6b model.
|
| 6 |
+
The inference scripts for ONNX and TensorRT will be refined later.
|
| 7 |
+
|
| 8 |
+
onnx2engine.py is used to convert the ONNX model into a TensorRT engine. The batch size is currently fixed at 1; it can be changed
|
| 9 |
+
to a dynamic batch size according to your available GPU memory.
|
onnx2engine.py
ADDED
|
@@ -0,0 +1,96 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
from itertools import tee

import tensorrt as trt
from tensorrt import MemoryPoolType, PreviewFeature

from polygraphy.backend.trt import (
    CreateConfig,
    Profile,
    engine_from_network,
    network_from_onnx_path,
    save_engine,
)
|
| 13 |
+
|
| 14 |
+
# Dynamic-shape bounds for the TensorRT optimization profile.
# Batch is fixed at 1 for now; change it (together with the min/opt/max
# shapes below) to build a dynamic-batch engine, subject to GPU memory.
batch_size = 1
max_length = 2048             # longest sequence the engine will accept
opt_length = max_length // 2  # shape TensorRT tunes its kernels for

# (min, opt, max) shape triples per network input.
# NOTE(review): the extra dims (2 for position_ids, 1 x L x L for
# attention_mask) follow the chatglm-6b ONNX export — confirm against the
# exported graph.
_PROFILE_SHAPES = {
    "input_ids": (
        (batch_size, 1),
        (batch_size, opt_length),
        (batch_size, max_length),
    ),
    "position_ids": (
        (batch_size, 2, 1),
        (batch_size, 2, opt_length),
        (batch_size, 2, max_length),
    ),
    "attention_mask": (
        (batch_size, 1, 1, 1),
        (batch_size, 1, opt_length, opt_length),
        (batch_size, 1, max_length, max_length),
    ),
}

_profile = Profile()
for _input_name, (_min_shape, _opt_shape, _max_shape) in _PROFILE_SHAPES.items():
    _profile.add(_input_name, min=_min_shape, opt=_opt_shape, max=_max_shape)
profiles = [_profile]
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def get_network_definition(network_definition):
    """Pin POW->REDUCE layer pairs to fp32 before an fp16 engine build.

    Walks consecutive layer pairs of the parsed network and, whenever an
    elementwise layer feeds a reduce layer (the LayerNorm variance pattern,
    ``pow`` followed by ``mean``), forces float32 precision and output type
    so the fp16 build does not lose accuracy there.

    Args:
        network_definition: the tuple returned by polygraphy's
            ``network_from_onnx_path``; only element ``[1]`` (the
            ``trt.INetworkDefinition``) is inspected and mutated in place.

    Returns:
        The same ``network_definition`` tuple, for call chaining.
    """

    def _pairwise(iterable):
        # itertools.pairwise equivalent, kept local for pre-3.10 Pythons.
        first, second = tee(iterable)
        next(second, None)
        return zip(first, second)

    network = network_definition[1]
    for idx, idx_next in _pairwise(range(network.num_layers)):
        layer = network.get_layer(idx)
        layer_next = network.get_layer(idx_next)

        # Skip layers whose outputs are shape tensors rather than data.
        if not all(layer.get_output(out).is_execution_tensor
                   for out in range(layer.num_outputs)):
            continue

        # Only touch layers that currently produce float32.
        if layer.get_output_type(0) != trt.float32:
            continue

        if (layer.type == trt.LayerType.ELEMENTWISE
                and layer_next.type == trt.LayerType.REDUCE):
            # Downcast so the elementwise-specific ``.op`` attribute is
            # accessible on the generic ILayer handle.
            layer.__class__ = getattr(trt, "IElementWiseLayer")
            if layer.op == trt.ElementWiseOperation.POW:
                layer.precision = trt.float32
                layer.set_output_type(0, trt.float32)

            layer_next.precision = trt.float32
            layer_next.set_output_type(0, trt.float32)

    return network_definition
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
# --- Script entry: build a TensorRT engine from the exported ONNX model ---

# Path to the chatglm-6b ONNX export (with past-key-values).
input_fpath = "./model6b_onnx_pkv/model.onnx"

# FASTER_DYNAMIC_SHAPES_0805 speeds up dynamic-shape engines on TRT 8.5.x.
# NOTE(review): this preview flag was removed in later TensorRT releases —
# drop or guard it when upgrading.
preview_features = [PreviewFeature.FASTER_DYNAMIC_SHAPES_0805]

trt_inference_config = CreateConfig(
    fp16=True,  # mixed precision; fp32 is forced where pinned per-layer
    memory_pool_limits={MemoryPoolType.WORKSPACE: 2048 * 1024 * 1024},  # 2 GiB scratch
    profiles=profiles,
    # Honor the per-layer fp32 pins. (The original ("obey") was the same
    # string value — the parentheses did not make it a tuple.)
    precision_constraints="obey",
    preview_features=preview_features,
)

onnx_network = network_from_onnx_path(input_fpath)

# Pin the LayerNorm POW/REDUCE pairs to fp32 before building.
network_definition = get_network_definition(onnx_network)
print(network_definition)
print(trt_inference_config)

# Engine build — the slow step (minutes; requires a GPU).
trt_engine = engine_from_network(network_definition, trt_inference_config)
print(trt_engine)

output_fpath = "./model6b_trt_pkv/out.engine"
# Create the output directory up front so save_engine does not fail on a
# missing path.
os.makedirs(os.path.dirname(output_fpath), exist_ok=True)
save_engine(trt_engine, output_fpath)
|