| | from absl import app, flags, logging |
| | from absl.flags import FLAGS |
| | import tensorflow as tf |
| | physical_devices = tf.config.experimental.list_physical_devices('GPU') |
| | if len(physical_devices) > 0: |
| | tf.config.experimental.set_memory_growth(physical_devices[0], True) |
| | import numpy as np |
| | import cv2 |
| | from tensorflow.python.compiler.tensorrt import trt_convert as trt |
| | import core.utils as utils |
| | from tensorflow.python.saved_model import signature_constants |
| | import os |
| | from tensorflow.compat.v1 import ConfigProto |
| | from tensorflow.compat.v1 import InteractiveSession |
| |
|
| | flags.DEFINE_string('weights', './checkpoints/yolov4-416', 'path to weights file') |
| | flags.DEFINE_string('output', './checkpoints/yolov4-trt-fp16-416', 'path to output') |
| | flags.DEFINE_integer('input_size', 416, 'path to output') |
| | flags.DEFINE_string('quantize_mode', 'float16', 'quantize mode (int8, float16)') |
| | flags.DEFINE_string('dataset', "/media/user/Source/Data/coco_dataset/coco/5k.txt", 'path to dataset') |
| | flags.DEFINE_integer('loop', 8, 'loop') |
| |
|
| | def representative_data_gen(): |
| | fimage = open(FLAGS.dataset).read().split() |
| | batched_input = np.zeros((FLAGS.loop, FLAGS.input_size, FLAGS.input_size, 3), dtype=np.float32) |
| | for input_value in range(FLAGS.loop): |
| | if os.path.exists(fimage[input_value]): |
| | original_image=cv2.imread(fimage[input_value]) |
| | original_image = cv2.cvtColor(original_image, cv2.COLOR_BGR2RGB) |
| | image_data = utils.image_preporcess(np.copy(original_image), [FLAGS.input_size, FLAGS.input_size]) |
| | img_in = image_data[np.newaxis, ...].astype(np.float32) |
| | batched_input[input_value, :] = img_in |
| | |
| | print(input_value) |
| | |
| | |
| | else: |
| | continue |
| | batched_input = tf.constant(batched_input) |
| | yield (batched_input,) |
| |
|
| | def save_trt(): |
| |
|
| | if FLAGS.quantize_mode == 'int8': |
| | conversion_params = trt.DEFAULT_TRT_CONVERSION_PARAMS._replace( |
| | precision_mode=trt.TrtPrecisionMode.INT8, |
| | max_workspace_size_bytes=4000000000, |
| | use_calibration=True, |
| | max_batch_size=8) |
| | converter = trt.TrtGraphConverterV2( |
| | input_saved_model_dir=FLAGS.weights, |
| | conversion_params=conversion_params) |
| | converter.convert(calibration_input_fn=representative_data_gen) |
| | elif FLAGS.quantize_mode == 'float16': |
| | conversion_params = trt.DEFAULT_TRT_CONVERSION_PARAMS._replace( |
| | precision_mode=trt.TrtPrecisionMode.FP16, |
| | max_workspace_size_bytes=4000000000, |
| | max_batch_size=8) |
| | converter = trt.TrtGraphConverterV2( |
| | input_saved_model_dir=FLAGS.weights, conversion_params=conversion_params) |
| | converter.convert() |
| | else : |
| | conversion_params = trt.DEFAULT_TRT_CONVERSION_PARAMS._replace( |
| | precision_mode=trt.TrtPrecisionMode.FP32, |
| | max_workspace_size_bytes=4000000000, |
| | max_batch_size=8) |
| | converter = trt.TrtGraphConverterV2( |
| | input_saved_model_dir=FLAGS.weights, conversion_params=conversion_params) |
| | converter.convert() |
| |
|
| | |
| | converter.save(output_saved_model_dir=FLAGS.output) |
| | print('Done Converting to TF-TRT') |
| |
|
| | saved_model_loaded = tf.saved_model.load(FLAGS.output) |
| | graph_func = saved_model_loaded.signatures[ |
| | signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY] |
| | trt_graph = graph_func.graph.as_graph_def() |
| | for n in trt_graph.node: |
| | print(n.op) |
| | if n.op == "TRTEngineOp": |
| | print("Node: %s, %s" % (n.op, n.name.replace("/", "_"))) |
| | else: |
| | print("Exclude Node: %s, %s" % (n.op, n.name.replace("/", "_"))) |
| | logging.info("model saved to: {}".format(FLAGS.output)) |
| |
|
| | trt_engine_nodes = len([1 for n in trt_graph.node if str(n.op) == 'TRTEngineOp']) |
| | print("numb. of trt_engine_nodes in TensorRT graph:", trt_engine_nodes) |
| | all_nodes = len([1 for n in trt_graph.node]) |
| | print("numb. of all_nodes in TensorRT graph:", all_nodes) |
| |
|
| | def main(_argv): |
| | config = ConfigProto() |
| | config.gpu_options.allow_growth = True |
| | session = InteractiveSession(config=config) |
| | save_trt() |
| |
|
| | if __name__ == '__main__': |
| | try: |
| | app.run(main) |
| | except SystemExit: |
| | pass |
| |
|
| |
|
| |
|