# sbb-binarization-onnx / fix_onnx.py
# Uploaded by nathansut1 via huggingface_hub (commit f00cc10, verified)
#!/usr/bin/env python3
"""Fix ONNX model for TensorRT compatibility.
Takes the raw tf2onnx output (model.onnx) and produces a TRT-optimized
model (model_convtranspose.onnx) in three steps:
Step 2A: Fix Reshape node with -2048 batch dim -> -1
Step 2B: Fix Resize nodes with TF-specific attributes
Step 2C: Validate model
Step 3: Replace Resize (nearest 2x upsample) -> ConvTranspose
Usage:
python3 fix_onnx.py <input.onnx> <output.onnx>
Requires: pip install onnx numpy
"""
import sys
from collections import Counter

import numpy as np
import onnx
from onnx import helper, numpy_helper, shape_inference
# Require the two positional arguments: input and output model paths.
if len(sys.argv) < 3:
    print(f"Usage: {sys.argv[0]} <input.onnx> <output.onnx>")
    raise SystemExit(1)
INPUT_MODEL, OUTPUT_MODEL = sys.argv[1], sys.argv[2]

print(f"Loading model from {INPUT_MODEL}...")
model = onnx.load(INPUT_MODEL)
print(f" opset: {[o.version for o in model.opset_import]}")
print(f" nodes: {len(model.graph.node)}")

# Human-readable log of every modification made; reported after saving.
fixes_applied = []
# ── Step 2A: Fix Reshape with -2048 ──────────────────────────────────────────
# tf2onnx sometimes bakes a bogus -2048 batch dimension into the constant
# `shape` input of Reshape; ONNX only allows -1 ("infer") as a negative, so
# rewrite any -2048 entry to -1.
print("\n=== Step 2A: Checking Reshape nodes for -2048 ===")
# Names of tensors used as the `shape` input (input[1]) of any Reshape node,
# collected once instead of rescanning all initializers per Reshape node.
_reshape_shape_names = {node.input[1] for node in model.graph.node
                        if node.op_type == "Reshape"}
# BUGFIX: iterate a snapshot of the initializer list.  The original code
# removed and re-inserted entries of model.graph.initializer while iterating
# that same protobuf repeated field, which is mutation-during-iteration.
for _idx, init in enumerate(list(model.graph.initializer)):
    if init.name not in _reshape_shape_names:
        continue
    shape = numpy_helper.to_array(init).copy()
    if -2048 in shape:
        print(f" FOUND -2048 in initializer '{init.name}': {shape}")
        shape[shape == -2048] = -1
        new_init = numpy_helper.from_array(shape, init.name)
        # Replace at the same position to keep initializer order stable.
        model.graph.initializer.remove(init)
        model.graph.initializer.insert(_idx, new_init)
        fixes_applied.append(f"Reshape: fixed -2048 -> -1 in '{init.name}'")
        print(f" FIXED -> {numpy_helper.to_array(new_init)}")
# ── Step 2B: Fix Resize node attributes ──────────────────────────────────────
# Map each Resize attribute name to the (TF-specific value, replacement)
# pair that TensorRT understands.
print("\n=== Step 2B: Checking Resize node attributes ===")
_RESIZE_ATTR_FIXES = {
    "nearest_mode": (b"floor", b"round_prefer_floor"),
    "coordinate_transformation_mode": (b"tf_half_pixel_for_nn", b"half_pixel"),
}
resize_count = 0
for node in model.graph.node:
    if node.op_type != "Resize":
        continue
    resize_count += 1
    node_fixes = []
    for attr in node.attribute:
        pair = _RESIZE_ATTR_FIXES.get(attr.name)
        if pair is not None and attr.s == pair[0]:
            old_val = attr.s.decode()
            attr.s = pair[1]  # mutate the attribute in place
            node_fixes.append(f"{attr.name}: {old_val} -> {pair[1].decode()}")
    if node_fixes:
        print(f" Resize '{node.name}': {', '.join(node_fixes)}")
        fixes_applied.extend(f"Resize '{node.name}': {f}" for f in node_fixes)
    else:
        # attr.type == 3 is AttributeProto.STRING; show readable values.
        attrs = {a.name: a.s.decode() if a.type == 3 else a for a in node.attribute}
        print(f" Resize '{node.name}': OK (attrs: {attrs})")
print(f" Total Resize nodes: {resize_count}")
# ── Step 2C: Scan for other issues ───────────────────────────────────────────
# Read-only sanity pass: nothing here mutates the model.
print("\n=== Step 2C: Scanning for other potential TRT issues ===")
# Warn about Reshape shape initializers that still hold values < -1
# (only -1 "infer" and 0 "copy" are legal special values in ONNX Reshape).
for node in model.graph.node:
    if node.op_type == "Reshape":
        for init in model.graph.initializer:
            if init.name == node.input[1]:
                shape = numpy_helper.to_array(init)
                negatives = shape[shape < -1]
                if len(negatives) > 0:
                    print(f" WARNING: Reshape '{init.name}' still has negative values: {shape}")
# stash_type != 1 (float32) on LayerNormalization is poorly supported —
# presumably a TRT limitation; flag it for review.
for node in model.graph.node:
    if node.op_type == "LayerNormalization":
        for attr in node.attribute:
            if attr.name == "stash_type" and attr.i != 1:
                print(f" WARNING: LayerNormalization '{node.name}' has stash_type={attr.i}")
# IDIOM: use collections.Counter instead of a manual get()+1 dict loop.
op_counts = Counter(node.op_type for node in model.graph.node)
print(f" Op type distribution:")
for op, count in sorted(op_counts.items()):
    print(f" {op}: {count}")
# ── Step 3: Replace Resize (nearest 2x) -> ConvTranspose ────────────────────
#
# Resize nodes doing nearest-neighbor 2x upsampling cause TRT to split the
# model into 8+ subgraphs, with GPU<->CPU copies at each boundary.
# Replacing them with depthwise ConvTranspose (group=channels, 2x2 kernel
# of all ones, stride 2) produces identical output but TRT compiles it as
# a single subgraph.
print("\n=== Step 3: Replacing Resize nodes with ConvTranspose ===")
# Run shape inference so we know tensor dimensions
print(" Running shape inference...")
try:
    model = shape_inference.infer_shapes(model)
    print(" Shape inference: OK")
except Exception as e:
    print(f" Shape inference WARNING: {e}")
# Build tensor shape lookup from everything shape inference annotated.
tensor_shapes = {}
for vi in list(model.graph.value_info) + list(model.graph.input) + list(model.graph.output):
    if vi.type.tensor_type.HasField('shape'):
        dims = [d.dim_value if d.dim_value > 0 else -1
                for d in vi.type.tensor_type.shape.dim]
        tensor_shapes[vi.name] = dims
# Constant tensors by name, for inspecting each Resize's `scales` input.
# (Built after infer_shapes, which returns a new model object.)
initializers_by_name = {init.name: init for init in model.graph.initializer}

def _is_plain_2x_upsample(node):
    """Return False only when the Resize's constant `scales` input proves it
    is NOT a uniform 2x upsample (any factor other than 1 or 2).

    The ConvTranspose replacement below is output-equivalent only for 2x
    nearest upsampling, so anything else must keep its Resize node.  When
    `scales` is dynamic (not an initializer) we cannot tell, and we keep the
    original behavior of attempting the replacement.
    """
    # Resize inputs are: X, roi, scales, sizes (trailing inputs optional).
    if len(node.input) > 2 and node.input[2] in initializers_by_name:
        scales = numpy_helper.to_array(initializers_by_name[node.input[2]])
        if scales.size > 0 and any(s not in (1.0, 2.0) for s in scales.tolist()):
            return False
    return True

resize_nodes = [n for n in model.graph.node if n.op_type == "Resize"]
replaced = 0
for resize_node in resize_nodes:
    # Only replace nearest-mode upsampling.
    # NOTE(review): an absent `mode` attribute defaults to "nearest" in the
    # ONNX spec; such nodes are skipped here (mode stays None) — confirm
    # tf2onnx always emits the attribute explicitly.
    mode = None
    for attr in resize_node.attribute:
        if attr.name == "mode":
            mode = attr.s.decode()
    if mode != "nearest":
        print(f" Skipping '{resize_node.name}' (mode={mode}, not nearest)")
        continue
    # BUGFIX: the replacement silently assumed a 2x scale factor; verify the
    # constant scales when available and skip provably-non-2x Resize nodes.
    if not _is_plain_2x_upsample(resize_node):
        print(f" Skipping '{resize_node.name}' (scales != 2x, ConvTranspose not equivalent)")
        continue
    # Get channel count from input shape
    # After shape inference, tensor may be NHWC [B,H,W,C] or NCHW [B,C,H,W].
    # Detect format: if dim[2]==dim[3] it's NCHW (square spatial), channels=dim[1].
    # NOTE(review): this heuristic misclassifies an NHWC tensor whose W==C;
    # presumably fine for this model family, but not general.
    input_name = resize_node.input[0]
    input_shape = tensor_shapes.get(input_name)
    if not input_shape or len(input_shape) != 4:
        print(f" Skipping '{resize_node.name}' (can't determine input shape)")
        continue
    if input_shape[2] == input_shape[3]:
        channels = input_shape[1]  # NCHW
    else:
        channels = input_shape[3]  # NHWC
    if channels <= 0:
        print(f" Skipping '{resize_node.name}' (dynamic channels)")
        continue
    # Create all-ones kernel: [channels, 1, 2, 2] for depthwise ConvTranspose
    kernel_name = f"{resize_node.name}_kernel"
    kernel_data = np.ones((channels, 1, 2, 2), dtype=np.float32)
    model.graph.initializer.append(numpy_helper.from_array(kernel_data, kernel_name))
    # Create ConvTranspose node (same input/output tensor names).  The old
    # roi/scales/sizes initializers may be left dangling; harmless to TRT.
    ct_node = helper.make_node(
        "ConvTranspose",
        inputs=[input_name, kernel_name],
        outputs=list(resize_node.output),
        name=f"{resize_node.name}_ConvTranspose",
        kernel_shape=[2, 2],
        strides=[2, 2],
        pads=[0, 0, 0, 0],
        group=channels,
    )
    # Swap in-place so the graph's topological order is preserved.
    idx = list(model.graph.node).index(resize_node)
    model.graph.node.remove(resize_node)
    model.graph.node.insert(idx, ct_node)
    replaced += 1
    fixes_applied.append(f"Resize '{resize_node.name}' -> ConvTranspose (group={channels})")
    print(f" Replaced '{resize_node.name}' -> ConvTranspose (channels={channels})")
print(f" Replaced {replaced} Resize nodes")
# ── Final validation ─────────────────────────────────────────────────────────
print("\n=== Final validation ===")
try:
model = shape_inference.infer_shapes(model)
print(" Shape inference: OK")
except Exception as e:
print(f" Shape inference WARNING: {e}")
try:
onnx.checker.check_model(model)
print(" Model validation: PASSED")
except Exception as e:
print(f" Model validation WARNING: {e}")
try:
onnx.checker.check_model(model, full_check=False)
print(" Model validation (relaxed): PASSED")
except Exception as e2:
print(f" Model validation (relaxed) FAILED: {e2}")
# ── Save ─────────────────────────────────────────────────────────────────────
print(f"\n=== Saving to {OUTPUT_MODEL} ===")
onnx.save(model, OUTPUT_MODEL)
print(f" Saved. Total fixes applied: {len(fixes_applied)}")
for f in fixes_applied:
print(f" - {f}")
remaining_resize = sum(1 for n in model.graph.node if n.op_type == "Resize")
total_ct = sum(1 for n in model.graph.node if n.op_type == "ConvTranspose")
print(f"\n Final: {len(model.graph.node)} nodes, {remaining_resize} Resize, {total_ct} ConvTranspose")
print("\nDone.")