# oneocr/_archive/replace_custom_op.py
# Branch: OneOCR Dev
# Commit ce847d4: "OneOCR - reverse engineering complete, ONNX pipeline 53% match rate"
"""Replace OneOCRFeatureExtract with standard Gemm and test the model.
This script modifies a model ONNX graph to replace the custom op, then runs inference."""
import onnx
from onnx import numpy_helper, helper, TensorProto
import numpy as np
from pathlib import Path
import onnxruntime as ort
import copy
models_dir = Path("oneocr_extracted/onnx_models")

# Load model_11. sorted() makes the pick deterministic; fail loudly if the
# extraction step hasn't produced the file yet (glob()[0] would raise an
# opaque IndexError).
model_candidates = sorted(models_dir.glob("model_11_*"))
if not model_candidates:
    raise FileNotFoundError(f"No model_11_* file found in {models_dir}")
model_path = model_candidates[0]
model = onnx.load(str(model_path))

# Extract the config blob (stored as a string tensor; payload is big-endian
# float32). Initialize `blob` so a missing initializer raises a clear error
# instead of a NameError at the frombuffer call below.
blob = None
for init in model.graph.initializer:
    if init.name == "feature/config":
        blob = bytes(init.string_data[0])
        break
if blob is None:
    raise ValueError("Initializer 'feature/config' not found in model graph")

# .copy() detaches the array from the read-only bytes buffer.
be_arr = np.frombuffer(blob, dtype='>f4').copy()
print(f"Config blob: {len(be_arr)} floats")

# Reverse-engineered blob layout: first 1050 floats = W[21x50],
# next 50 = bias, remainder = metadata of unknown meaning.
W_fe = be_arr[:1050].reshape(21, 50).astype(np.float32)
b_fe = be_arr[1050:1100].astype(np.float32)
metadata = be_arr[1100:]
print(f"W: {W_fe.shape}, b: {b_fe.shape}, metadata: {metadata.shape}")
print(f"Metadata values: {metadata}")
# Build a modified model: swap the custom op for a standard Gemm.
# OneOCRFeatureExtract consumes ['29', 'feature/config'] → ['oneocr_feature'];
# the replacement is Gemm(['29', 'fe_weight', 'fe_bias']) → ['oneocr_feature'].
new_model = copy.deepcopy(model)

# Drop the opaque config blob and register the decoded weight/bias tensors
# as ordinary initializers instead.
new_inits = [init for init in new_model.graph.initializer
             if init.name != "feature/config"]
W_init = numpy_helper.from_array(W_fe.T, name="fe_weight")  # [50, 21] for transB=1
b_init = numpy_helper.from_array(b_fe, name="fe_bias")
new_inits.extend([W_init, b_init])
del new_model.graph.initializer[:]
new_model.graph.initializer.extend(new_inits)
# Replace every OneOCRFeatureExtract node with an equivalent Gemm.
new_nodes = []
replaced = 0
for node in new_model.graph.node:
    if node.op_type == "OneOCRFeatureExtract":
        # Input '29' has shape [1, 21]; output 'oneocr_feature' should be [1, 50].
        gemm_node = helper.make_node(
            "Gemm",
            inputs=["29", "fe_weight", "fe_bias"],
            outputs=["oneocr_feature"],
            alpha=1.0,
            beta=1.0,
            transB=1,
        )
        new_nodes.append(gemm_node)
        replaced += 1
        # plain string: original used an f-string with no placeholders (F541)
        print("Replaced OneOCRFeatureExtract with Gemm(29 @ W.T + b)")
    else:
        new_nodes.append(node)
# Surface the silent-no-op case: if the custom node is absent, the rest of the
# script would "succeed" while testing an unmodified graph.
if replaced == 0:
    print("WARNING: no OneOCRFeatureExtract node found; graph left unchanged")
del new_model.graph.node[:]
new_model.graph.node.extend(new_nodes)
# The new weights live as initializers, not graph inputs, so prune any stale
# 'feature/config' entry from the input list.
new_inputs = [inp for inp in new_model.graph.input if inp.name != "feature/config"]
del new_model.graph.input[:]
new_model.graph.input.extend(new_inputs)

# With the custom op gone, the com.microsoft.oneocr opset import is dead weight.
new_opsets = [op for op in new_model.opset_import if op.domain != "com.microsoft.oneocr"]
del new_model.opset_import[:]
new_model.opset_import.extend(new_opsets)

# Validate the rewritten graph; non-fatal so we still get a saved model to poke at.
try:
    onnx.checker.check_model(new_model)
    print("Model validation passed!")
except Exception as e:
    print(f"Model validation warning: {e}")

# Persist the modified model.
modified_path = "temp/model_11_modified.onnx"
Path("temp").mkdir(exist_ok=True)
onnx.save(new_model, modified_path)
print(f"Saved modified model to {modified_path}")
# --- Inference smoke tests ---
# plain string: original used an f-string with no placeholders (F541)
print("\n--- Testing inference ---")

# The original model is expected to fail: onnxruntime has no kernel for the
# custom OneOCRFeatureExtract op.
try:
    sess_orig = ort.InferenceSession(str(model_path))
    print("Original model loaded (unexpected!)")
except Exception as e:
    print(f"Original model failed (expected): {str(e)[:100]}")

# The modified model should load and run with plain Gemm.
# NOTE(review): assumes the graph input tensor is named "data" with shape
# [1, 21, 1, 1] — confirm against the extracted model.
try:
    sess_mod = ort.InferenceSession(modified_path)
    print("Modified model loaded successfully!")

    # All-zero input: the Gemm output is driven purely by the bias term.
    test_input = np.zeros((1, 21, 1, 1), dtype=np.float32)
    result = sess_mod.run(None, {"data": test_input})
    print(f"Zero input → softmax: {result[0]}")

    # Small random input.
    test_input = np.random.randn(1, 21, 1, 1).astype(np.float32) * 0.5
    result = sess_mod.run(None, {"data": test_input})
    print(f"Random input → softmax: {result[0]}")

    # Hand-picked values resembling normalized CTC scores.
    test_input = np.array([
        0.9, 0.1, 0.05, 0.02, 0.01, 0.3, 0.7, 0.6, 0.4, 0.5,
        0.3, 0.01, 0.02, 0.01, 0.01, 0.01, 0.02, 0.01, 0.01, 0.01, 0.8
    ], dtype=np.float32).reshape(1, 21, 1, 1)
    result = sess_mod.run(None, {"data": test_input})
    print(f"Typical scores → softmax: {result[0]}")
except Exception as e:
    print(f"Modified model failed: {e}")
    import traceback
    traceback.print_exc()
# --- Variant: Gemm followed by ReLU ---
# The custom op might bake in an activation; test that hypothesis too.
print(f"\n--- Testing with ReLU after feature extraction ---")
new_model2 = copy.deepcopy(model)

# Same initializer surgery as the first rewrite: drop the config blob and add
# the decoded weight/bias tensors.
kept_inits = [i for i in new_model2.graph.initializer if i.name != "feature/config"]
kept_inits.append(numpy_helper.from_array(W_fe.T, name="fe_weight"))
kept_inits.append(numpy_helper.from_array(b_fe, name="fe_bias"))
del new_model2.graph.initializer[:]
new_model2.graph.initializer.extend(kept_inits)

# Swap the custom op for a Gemm → Relu pair this time.
rebuilt_nodes = []
for node in new_model2.graph.node:
    if node.op_type != "OneOCRFeatureExtract":
        rebuilt_nodes.append(node)
        continue
    rebuilt_nodes.append(helper.make_node(
        "Gemm",
        inputs=["29", "fe_weight", "fe_bias"],
        outputs=["oneocr_feature_pre"],
        alpha=1.0, beta=1.0, transB=1,
    ))
    rebuilt_nodes.append(helper.make_node(
        "Relu", inputs=["oneocr_feature_pre"], outputs=["oneocr_feature"],
    ))
del new_model2.graph.node[:]
new_model2.graph.node.extend(rebuilt_nodes)

# Strip the stale graph input and the custom-op opset import.
pruned_inputs = [inp for inp in new_model2.graph.input if inp.name != "feature/config"]
del new_model2.graph.input[:]
new_model2.graph.input.extend(pruned_inputs)
pruned_opsets = [op for op in new_model2.opset_import if op.domain != "com.microsoft.oneocr"]
del new_model2.opset_import[:]
new_model2.opset_import.extend(pruned_opsets)

modified_path2 = "temp/model_11_modified_relu.onnx"
onnx.save(new_model2, modified_path2)

# Smoke-test the ReLU variant with zero and random inputs.
try:
    sess_mod2 = ort.InferenceSession(modified_path2)
    test_input = np.zeros((1, 21, 1, 1), dtype=np.float32)
    result = sess_mod2.run(None, {"data": test_input})
    print(f"Zero input (Gemm+ReLU) → softmax: {result[0]}")
    test_input = np.random.randn(1, 21, 1, 1).astype(np.float32) * 0.5
    result = sess_mod2.run(None, {"data": test_input})
    print(f"Random input (Gemm+ReLU) → softmax: {result[0]}")
except Exception as e:
    print(f"Failed: {e}")