| |
| """Fix ONNX model for TensorRT compatibility. |
| |
| Takes the raw tf2onnx output (model.onnx) and produces a TRT-optimized |
| model (model_convtranspose.onnx) in three steps: |
| |
| Step 2A: Fix Reshape node with -2048 batch dim -> -1 |
| Step 2B: Fix Resize nodes with TF-specific attributes |
| Step 2C: Validate model |
| Step 3: Replace Resize (nearest 2x upsample) -> ConvTranspose |
| |
| Usage: |
| python3 fix_onnx.py <input.onnx> <output.onnx> |
| |
| Requires: pip install onnx numpy |
| """ |
|
|
import sys
from collections import Counter

import numpy as np
import onnx
from onnx import helper, numpy_helper, shape_inference
|
|
# Exactly two positional arguments are required: input and output paths.
if len(sys.argv) < 3:
    print(f"Usage: {sys.argv[0]} <input.onnx> <output.onnx>")
    sys.exit(1)

INPUT_MODEL, OUTPUT_MODEL = sys.argv[1], sys.argv[2]
|
|
| print(f"Loading model from {INPUT_MODEL}...") |
| model = onnx.load(INPUT_MODEL) |
| print(f" opset: {[o.version for o in model.opset_import]}") |
| print(f" nodes: {len(model.graph.node)}") |
|
|
| fixes_applied = [] |
|
|
| |
| print("\n=== Step 2A: Checking Reshape nodes for -2048 ===") |
| for node in model.graph.node: |
| if node.op_type == "Reshape": |
| for init in model.graph.initializer: |
| if init.name == node.input[1]: |
| shape = numpy_helper.to_array(init).copy() |
| if -2048 in shape: |
| print(f" FOUND -2048 in initializer '{init.name}': {shape}") |
| shape[shape == -2048] = -1 |
| new_init = numpy_helper.from_array(shape, init.name) |
| idx = list(model.graph.initializer).index(init) |
| model.graph.initializer.remove(init) |
| model.graph.initializer.insert(idx, new_init) |
| fixes_applied.append(f"Reshape: fixed -2048 -> -1 in '{init.name}'") |
| print(f" FIXED -> {numpy_helper.to_array(new_init)}") |
|
|
| |
| print("\n=== Step 2B: Checking Resize node attributes ===") |
| resize_count = 0 |
| for node in model.graph.node: |
| if node.op_type == "Resize": |
| resize_count += 1 |
| node_fixes = [] |
| for attr in node.attribute: |
| if attr.name == "nearest_mode" and attr.s == b"floor": |
| old_val = attr.s.decode() |
| attr.s = b"round_prefer_floor" |
| node_fixes.append(f"nearest_mode: {old_val} -> round_prefer_floor") |
| if attr.name == "coordinate_transformation_mode" and attr.s == b"tf_half_pixel_for_nn": |
| old_val = attr.s.decode() |
| attr.s = b"half_pixel" |
| node_fixes.append(f"coordinate_transformation_mode: {old_val} -> half_pixel") |
| if node_fixes: |
| print(f" Resize '{node.name}': {', '.join(node_fixes)}") |
| fixes_applied.extend([f"Resize '{node.name}': {f}" for f in node_fixes]) |
| else: |
| attrs = {a.name: a.s.decode() if a.type == 3 else a for a in node.attribute} |
| print(f" Resize '{node.name}': OK (attrs: {attrs})") |
| print(f" Total Resize nodes: {resize_count}") |
|
|
| |
| print("\n=== Step 2C: Scanning for other potential TRT issues ===") |
|
|
| for node in model.graph.node: |
| if node.op_type == "Reshape": |
| for init in model.graph.initializer: |
| if init.name == node.input[1]: |
| shape = numpy_helper.to_array(init) |
| negatives = shape[shape < -1] |
| if len(negatives) > 0: |
| print(f" WARNING: Reshape '{init.name}' still has negative values: {shape}") |
|
|
| for node in model.graph.node: |
| if node.op_type == "LayerNormalization": |
| for attr in node.attribute: |
| if attr.name == "stash_type" and attr.i != 1: |
| print(f" WARNING: LayerNormalization '{node.name}' has stash_type={attr.i}") |
|
|
| op_counts = {} |
| for node in model.graph.node: |
| op_counts[node.op_type] = op_counts.get(node.op_type, 0) + 1 |
| print(f" Op type distribution:") |
| for op, count in sorted(op_counts.items()): |
| print(f" {op}: {count}") |
|
|
| |
| |
| |
| |
| |
| |
| |
| print("\n=== Step 3: Replacing Resize nodes with ConvTranspose ===") |
|
|
| |
| print(" Running shape inference...") |
| try: |
| model = shape_inference.infer_shapes(model) |
| print(" Shape inference: OK") |
| except Exception as e: |
| print(f" Shape inference WARNING: {e}") |
|
|
| |
| tensor_shapes = {} |
| for vi in list(model.graph.value_info) + list(model.graph.input) + list(model.graph.output): |
| if vi.type.tensor_type.HasField('shape'): |
| dims = [d.dim_value if d.dim_value > 0 else -1 |
| for d in vi.type.tensor_type.shape.dim] |
| tensor_shapes[vi.name] = dims |
|
|
resize_nodes = [n for n in model.graph.node if n.op_type == "Resize"]
replaced = 0

# Fast lookup of constant inputs (roi/scales/sizes) by name.
_initializers = {init.name: init for init in model.graph.initializer}

for resize_idx, resize_node in enumerate(resize_nodes):
    # Per the ONNX spec, mode defaults to "nearest" when the attribute is
    # absent (the original treated a missing attribute as non-nearest and
    # skipped the node).
    mode = "nearest"
    for attr in resize_node.attribute:
        if attr.name == "mode":
            mode = attr.s.decode()
    if mode != "nearest":
        print(f" Skipping '{resize_node.name}' (mode={mode}, not nearest)")
        continue

    input_name = resize_node.input[0]
    input_shape = tensor_shapes.get(input_name)
    if not input_shape or len(input_shape) != 4:
        print(f" Skipping '{resize_node.name}' (can't determine input shape)")
        continue

    # Verify this really is a uniform 2x upsample before swapping in a
    # stride-2 ConvTranspose. A 4-element float initializer among the
    # optional inputs is the "scales" tensor (roi has 8 floats for 4-D
    # data, sizes is int64); every non-1 scale factor must be 2.
    scale_factors = None
    for extra_input in resize_node.input[1:]:
        init = _initializers.get(extra_input)
        if init is not None:
            arr = numpy_helper.to_array(init)
            if arr.size == 4 and np.issubdtype(arr.dtype, np.floating):
                scale_factors = arr
    if scale_factors is not None:
        spatial = scale_factors[scale_factors != 1.0]
        if spatial.size and not np.allclose(spatial, 2.0):
            print(f" Skipping '{resize_node.name}' (not a 2x upsample: scales={scale_factors})")
            continue

    # Guess the channel axis from the static shape. NOTE(review): this
    # assumes NCHW when H == W; the ConvTranspose built below (weights
    # (C, 1, 2, 2), group=C) is only correct for NCHW input -- confirm the
    # NHWC branch never fires on real models.
    if input_shape[2] == input_shape[3]:
        channels = input_shape[1]
    else:
        channels = input_shape[3]
    if channels <= 0:
        print(f" Skipping '{resize_node.name}' (dynamic channels)")
        continue

    # All-ones 2x2 depthwise kernel with stride 2 copies each pixel into a
    # 2x2 block, which reproduces nearest-neighbor 2x upsampling exactly.
    # Fall back to the node index so unnamed nodes get unique weight names
    # (the original would emit the colliding name "_kernel" for each).
    kernel_name = f"{resize_node.name or f'Resize_{resize_idx}'}_kernel"
    kernel_data = np.ones((channels, 1, 2, 2), dtype=np.float32)
    model.graph.initializer.append(numpy_helper.from_array(kernel_data, kernel_name))

    ct_node = helper.make_node(
        "ConvTranspose",
        inputs=[input_name, kernel_name],
        outputs=list(resize_node.output),
        name=f"{resize_node.name}_ConvTranspose",
        kernel_shape=[2, 2],
        strides=[2, 2],
        pads=[0, 0, 0, 0],
        group=channels,
    )

    # Splice the new node in at the Resize's position to preserve the
    # graph's topological order.
    idx = list(model.graph.node).index(resize_node)
    model.graph.node.remove(resize_node)
    model.graph.node.insert(idx, ct_node)

    replaced += 1
    fixes_applied.append(f"Resize '{resize_node.name}' -> ConvTranspose (group={channels})")
    print(f" Replaced '{resize_node.name}' -> ConvTranspose (channels={channels})")

print(f" Replaced {replaced} Resize nodes")
|
|
| |
| print("\n=== Final validation ===") |
| try: |
| model = shape_inference.infer_shapes(model) |
| print(" Shape inference: OK") |
| except Exception as e: |
| print(f" Shape inference WARNING: {e}") |
|
|
| try: |
| onnx.checker.check_model(model) |
| print(" Model validation: PASSED") |
| except Exception as e: |
| print(f" Model validation WARNING: {e}") |
| try: |
| onnx.checker.check_model(model, full_check=False) |
| print(" Model validation (relaxed): PASSED") |
| except Exception as e2: |
| print(f" Model validation (relaxed) FAILED: {e2}") |
|
|
| |
| print(f"\n=== Saving to {OUTPUT_MODEL} ===") |
| onnx.save(model, OUTPUT_MODEL) |
| print(f" Saved. Total fixes applied: {len(fixes_applied)}") |
| for f in fixes_applied: |
| print(f" - {f}") |
|
|
| remaining_resize = sum(1 for n in model.graph.node if n.op_type == "Resize") |
| total_ct = sum(1 for n in model.graph.node if n.op_type == "ConvTranspose") |
| print(f"\n Final: {len(model.graph.node)} nodes, {remaining_resize} Resize, {total_ct} ConvTranspose") |
| print("\nDone.") |
|
|