| import numpy as np |
| import onnx |
| from onnx import shape_inference |
| try: |
| import onnx_graphsurgeon as gs |
| except Exception as e: |
| print('Import onnx_graphsurgeon failure: %s' % e) |
|
|
| import logging |
|
|
# Module-scoped logger; handlers and levels are left to the application.
LOGGER = logging.getLogger(name=__name__)
|
|
class RegisterNMS(object):
    """Append a TensorRT ``EfficientNMS_TRT`` plugin node to an ONNX model.

    The model is loaded with ONNX GraphSurgeon. :meth:`register_nms` feeds the
    model's current graph outputs into the plugin and replaces the graph
    outputs with the plugin's four outputs. :meth:`save` writes the result.
    """

    # Dtype declared for the NMS box/score outputs, keyed by ``precision``.
    _OUTPUT_DTYPES = {
        "fp32": np.float32,
        "fp16": np.float16,
    }

    def __init__(
        self,
        onnx_model_path: str,
        precision: str = "fp32",
    ):
        """Load the ONNX model into a GraphSurgeon graph and fold constants.

        Args:
            onnx_model_path: Path of the ONNX model to load.
            precision: ``"fp32"`` or ``"fp16"``. Selects the dtype declared for
                the NMS box/score outputs. It is validated when
                :meth:`register_nms` runs, not here.

        Raises:
            ValueError: If GraphSurgeon returns a falsy graph.
        """
        self.graph = gs.import_onnx(onnx.load(onnx_model_path))
        # Explicit check rather than ``assert`` so it survives ``python -O``.
        if not self.graph:
            raise ValueError(f"Failed to import ONNX graph from {onnx_model_path}")
        LOGGER.info("ONNX graph created successfully")

        self.graph.fold_constants()
        self.precision = precision
        # Batch dimension written into the declared NMS output shapes.
        self.batch_size = 1

    @classmethod
    def _output_dtype(cls, precision):
        """Return the numpy dtype used for box/score outputs at ``precision``.

        Raises:
            NotImplementedError: If ``precision`` is not ``"fp32"`` or ``"fp16"``.
        """
        try:
            return cls._OUTPUT_DTYPES[precision]
        except (KeyError, TypeError):
            raise NotImplementedError(
                f"Currently not supports precision: {precision}"
            ) from None

    def infer(self):
        """Sanitize the graph and, when possible, run ONNX shape inference.

        Each pass does the following:

        1. Removes unconnected nodes and re-sorts the graph topologically.
        2. Clears all node-output shapes and re-infers them with ONNX shape
           inference. This step is best effort: failures are logged and the
           graph is kept as is.
        3. Folds constants, including shape computations.

        At most three passes run. The loop stops early once a pass leaves the
        node count unchanged.

        Raises:
            TypeError: If the installed ONNX GraphSurgeon rejects
                ``fold_constants(fold_shapes=True)``.
        """
        for _ in range(3):
            count_before = len(self.graph.nodes)

            self.graph.cleanup().toposort()
            try:
                # Drop existing shapes so inference recomputes them from scratch.
                for node in self.graph.nodes:
                    for o in node.outputs:
                        o.shape = None
                model = gs.export_onnx(self.graph)
                model = shape_inference.infer_shapes(model)
                self.graph = gs.import_onnx(model)
            except Exception as e:
                # Shape inference is optional here; continue with the graph as is.
                LOGGER.info(f"Shape inference could not be performed at this time:\n{e}")
            try:
                self.graph.fold_constants(fold_shapes=True)
            except TypeError as e:
                LOGGER.error(
                    "This version of ONNX GraphSurgeon does not support folding shapes, "
                    f"please upgrade your onnx_graphsurgeon module. Error:\n{e}"
                )
                raise

            count_after = len(self.graph.nodes)
            if count_before == count_after:
                # Node count is stable; treat the graph as converged.
                break

    def save(self, output_path):
        """Clean up, topologically sort and write the ONNX model to ``output_path``.

        Args:
            output_path: Path pointing to the location where to write
                out the updated ONNX model.
        """
        self.graph.cleanup().toposort()
        model = gs.export_onnx(self.graph)
        onnx.save(model, output_path)
        LOGGER.info(f"Saved ONNX model to {output_path}")

    def register_nms(
        self,
        *,
        score_thresh: float = 0.25,
        nms_thresh: float = 0.45,
        detections_per_img: int = 100,
    ):
        """Register the ``EfficientNMS_TRT`` plugin node.

        The graph's current outputs, in order, become the plugin inputs.
        NMS expects these shapes for its input tensors:

        - box_net: [batch_size, number_boxes, 4]
        - class_net: [batch_size, number_boxes, number_labels]

        Afterwards the graph outputs are ``num_dets``, ``det_boxes``,
        ``det_scores`` and ``det_classes``.

        Args:
            score_thresh (float): The scalar threshold for score (low scoring
                boxes are removed).
            nms_thresh (float): The scalar threshold for IOU (new boxes that
                have high IOU overlap with previously selected boxes are removed).
            detections_per_img (int): Number of best detections to keep after NMS.

        Raises:
            NotImplementedError: If ``self.precision`` is not ``"fp32"`` or
                ``"fp16"``. This is raised before the graph is modified.
        """
        # Validate first so an unsupported precision leaves the graph untouched.
        dtype_output = self._output_dtype(self.precision)

        self.infer()
        # Snapshot the current outputs; ``self.graph.outputs`` is replaced below.
        op_inputs = list(self.graph.outputs)
        op = "EfficientNMS_TRT"
        attrs = {
            "plugin_version": "1",
            "background_class": -1,
            "max_output_boxes": detections_per_img,
            "score_threshold": score_thresh,
            "iou_threshold": nms_thresh,
            "score_activation": False,
            "box_coding": 0,
        }

        output_num_detections = gs.Variable(
            name="num_dets",
            dtype=np.int32,
            shape=[self.batch_size, 1],
        )
        output_boxes = gs.Variable(
            name="det_boxes",
            dtype=dtype_output,
            shape=[self.batch_size, detections_per_img, 4],
        )
        output_scores = gs.Variable(
            name="det_scores",
            dtype=dtype_output,
            shape=[self.batch_size, detections_per_img],
        )
        output_labels = gs.Variable(
            name="det_classes",
            dtype=np.int32,
            shape=[self.batch_size, detections_per_img],
        )

        op_outputs = [output_num_detections, output_boxes, output_scores, output_labels]

        self.graph.layer(op=op, name="batched_nms", inputs=op_inputs, outputs=op_outputs, attrs=attrs)
        LOGGER.info(f"Created NMS plugin '{op}' with attributes: {attrs}")

        self.graph.outputs = op_outputs

        # Record the declared shapes now. infer() sets every node-output shape
        # to None (these plugin outputs included) before shape inference runs.
        declared_shapes = {t.name: list(t.shape) for t in op_outputs}
        self.infer()
        # ONNX shape inference has no schema for this TensorRT plugin op, so it
        # cannot recompute these shapes. Restore them on the graph outputs,
        # matched by name because infer() may have re-imported the graph.
        for tensor in self.graph.outputs:
            if tensor.shape is None and tensor.name in declared_shapes:
                tensor.shape = declared_shapes[tensor.name]
|
|