|
|
"""Replace OneOCRFeatureExtract with standard Gemm and test the model. |
|
|
This script modifies the model's ONNX graph to replace the custom op, then runs inference."""
|
|
import onnx |
|
|
from onnx import numpy_helper, helper, TensorProto |
|
|
import numpy as np |
|
|
from pathlib import Path |
|
|
import onnxruntime as ort |
|
|
import copy |
|
|
|
|
|
# Locate the extracted ONNX model whose custom feature-extract op we replace.
models_dir = Path("oneocr_extracted/onnx_models")

# Sort so the choice is deterministic (glob order is filesystem-dependent),
# and fail with a clear message instead of an opaque IndexError when nothing
# matches.
candidates = sorted(models_dir.glob("model_11_*"))
if not candidates:
    raise FileNotFoundError(f"No model_11_* file found in {models_dir}")
model_path = candidates[0]

model = onnx.load(str(model_path))
|
|
|
|
|
|
|
|
# Pull the raw "feature/config" blob out of the graph initializers.  The
# custom OneOCRFeatureExtract op stores its parameters as a single
# big-endian float32 byte string inside this initializer.
for init in model.graph.initializer:
    if init.name == "feature/config":
        blob = bytes(init.string_data[0])
        break
else:
    # Without the guard, a missing initializer left `blob` unbound and the
    # script died later with a confusing NameError.
    raise RuntimeError("initializer 'feature/config' not found in the model")

# Decode as big-endian float32 ('>f4'); .copy() detaches the array from the
# read-only bytes buffer so downstream ops can work on it freely.
be_arr = np.frombuffer(blob, dtype='>f4').copy()
print(f"Config blob: {len(be_arr)} floats")

# Layout inferred from sizes: a 21x50 weight matrix, 50 biases, then trailing
# metadata floats -- TODO confirm against the custom-op implementation.
W_fe = be_arr[:1050].reshape(21, 50).astype(np.float32)
b_fe = be_arr[1050:1100].astype(np.float32)
metadata = be_arr[1100:]
print(f"W: {W_fe.shape}, b: {b_fe.shape}, metadata: {metadata.shape}")
print(f"Metadata values: {metadata}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Work on a deep copy so the original protobuf stays untouched.
new_model = copy.deepcopy(model)

# Drop the opaque "feature/config" blob and register plain Gemm weights in
# its place.  The weight is stored transposed (50x21) to pair with transB=1.
kept_inits = [i for i in new_model.graph.initializer if i.name != "feature/config"]
kept_inits.append(numpy_helper.from_array(W_fe.T, name="fe_weight"))
kept_inits.append(numpy_helper.from_array(b_fe, name="fe_bias"))

del new_model.graph.initializer[:]
new_model.graph.initializer.extend(kept_inits)
|
|
|
|
|
|
|
|
# Swap every OneOCRFeatureExtract node for a standard Gemm over the new
# initializers.  The custom op consumed tensor "29" plus the config blob and
# produced "oneocr_feature"; Gemm with transB=1 computes x @ W_fe + b because
# the stored weight is W_fe.T and transB transposes it back.
new_nodes = []
for node in new_model.graph.node:
    if node.op_type == "OneOCRFeatureExtract":
        gemm_node = helper.make_node(
            "Gemm",
            inputs=["29", "fe_weight", "fe_bias"],
            outputs=["oneocr_feature"],
            alpha=1.0,
            beta=1.0,
            transB=1,
        )
        new_nodes.append(gemm_node)
        # Plain string: the message has no placeholders (was a needless f-string).
        print("Replaced OneOCRFeatureExtract with Gemm(29 @ W.T + b)")
    else:
        new_nodes.append(node)

del new_model.graph.node[:]
new_model.graph.node.extend(new_nodes)
|
|
|
|
|
|
|
|
|
|
|
# The "feature/config" blob is gone, so remove the matching graph input.
surviving_inputs = [inp for inp in new_model.graph.input if inp.name != "feature/config"]
del new_model.graph.input[:]
new_model.graph.input.extend(surviving_inputs)

# No node references the custom-op domain anymore; drop its opset import so
# the checker and runtime stop looking for it.
surviving_opsets = [op for op in new_model.opset_import if op.domain != "com.microsoft.oneocr"]
del new_model.opset_import[:]
new_model.opset_import.extend(surviving_opsets)
|
|
|
|
|
|
|
|
# Validate the rewritten graph; a failure is reported but not fatal, since
# the real test is whether onnxruntime can load and run the model below.
try:
    onnx.checker.check_model(new_model)
except Exception as e:
    print(f"Model validation warning: {e}")
else:
    print("Model validation passed!")

# Persist the patched model for the inference tests.
modified_path = "temp/model_11_modified.onnx"
Path("temp").mkdir(exist_ok=True)
onnx.save(new_model, modified_path)
print(f"Saved modified model to {modified_path}")
|
|
|
|
|
|
|
|
# No placeholders in the banner, so a plain string suffices (was f-string).
print("\n--- Testing inference ---")

# Sanity check: the unmodified model references the com.microsoft.oneocr
# custom op, so onnxruntime is expected to refuse to load it.
try:
    sess_orig = ort.InferenceSession(str(model_path))
except Exception as e:
    print(f"Original model failed (expected): {str(e)[:100]}")
else:
    print("Original model loaded (unexpected!)")
|
|
|
|
|
|
|
|
try:
    sess_mod = ort.InferenceSession(modified_path)
    print("Modified model loaded successfully!")

    # Hand-picked score vector resembling real detector outputs.
    typical = np.array([
        0.9, 0.1, 0.05, 0.02, 0.01, 0.3, 0.7, 0.6, 0.4, 0.5,
        0.3, 0.01, 0.02, 0.01, 0.01, 0.01, 0.02, 0.01, 0.01, 0.01, 0.8
    ], dtype=np.float32).reshape(1, 21, 1, 1)

    # Probe the patched graph with three input profiles and show the softmax.
    probes = [
        ("Zero input", np.zeros((1, 21, 1, 1), dtype=np.float32)),
        ("Random input", np.random.randn(1, 21, 1, 1).astype(np.float32) * 0.5),
        ("Typical scores", typical),
    ]
    for label, test_input in probes:
        result = sess_mod.run(None, {"data": test_input})
        print(f"{label} → softmax: {result[0]}")
except Exception as e:
    print(f"Modified model failed: {e}")
    import traceback
    traceback.print_exc()
|
|
|
|
|
|
|
|
# No placeholders in the banner, so a plain string suffices (was f-string).
print("\n--- Testing with ReLU after feature extraction ---")

# Second hypothesis: the custom op may apply an activation after the affine
# map.  Build the same Gemm replacement but follow it with a ReLU.
new_model2 = copy.deepcopy(model)

# Same initializer surgery as the first variant: drop the config blob, add
# the transposed weight and the bias.
inits2 = [i for i in new_model2.graph.initializer if i.name != "feature/config"]
inits2.append(numpy_helper.from_array(W_fe.T, name="fe_weight"))
inits2.append(numpy_helper.from_array(b_fe, name="fe_bias"))
del new_model2.graph.initializer[:]
new_model2.graph.initializer.extend(inits2)

new_nodes2 = []
for node in new_model2.graph.node:
    if node.op_type == "OneOCRFeatureExtract":
        # Gemm writes to an intermediate tensor; ReLU produces the name the
        # rest of the graph consumes ("oneocr_feature").
        new_nodes2.append(helper.make_node(
            "Gemm", inputs=["29", "fe_weight", "fe_bias"],
            outputs=["oneocr_feature_pre"], alpha=1.0, beta=1.0, transB=1))
        new_nodes2.append(helper.make_node(
            "Relu", inputs=["oneocr_feature_pre"], outputs=["oneocr_feature"]))
    else:
        new_nodes2.append(node)
del new_model2.graph.node[:]
new_model2.graph.node.extend(new_nodes2)

# Prune the stale graph input and the custom-op opset, as before.
new_inputs2 = [inp for inp in new_model2.graph.input if inp.name != "feature/config"]
del new_model2.graph.input[:]
new_model2.graph.input.extend(new_inputs2)

new_opsets2 = [op for op in new_model2.opset_import if op.domain != "com.microsoft.oneocr"]
del new_model2.opset_import[:]
new_model2.opset_import.extend(new_opsets2)
|
|
|
|
|
# Persist and smoke-test the Gemm+ReLU variant on the same probe inputs.
modified_path2 = "temp/model_11_modified_relu.onnx"
onnx.save(new_model2, modified_path2)

try:
    sess_mod2 = ort.InferenceSession(modified_path2)
    for label, probe in (
        ("Zero input", np.zeros((1, 21, 1, 1), dtype=np.float32)),
        ("Random input", np.random.randn(1, 21, 1, 1).astype(np.float32) * 0.5),
    ):
        out = sess_mod2.run(None, {"data": probe})
        print(f"{label} (Gemm+ReLU) → softmax: {out[0]}")
except Exception as e:
    print(f"Failed: {e}")
|
|
|