"""Replace the custom OneOCRFeatureExtract op with a standard Gemm and test the model.

This script rewrites an ONNX model graph so that the custom op is replaced by
standard ONNX operators, saves the result, and then runs inference on it.
"""
import copy
from pathlib import Path

import numpy as np
import onnx
import onnxruntime as ort
from onnx import TensorProto, helper, numpy_helper

models_dir = Path("oneocr_extracted/onnx_models")

# Pick the model deterministically: Path.glob yields entries in an arbitrary,
# filesystem-dependent order, and indexing an empty result would only raise an
# unhelpful IndexError.
model_candidates = sorted(models_dir.glob("model_11_*"))
if not model_candidates:
    raise FileNotFoundError(f"No file matching 'model_11_*' found in {models_dir}")
model_path = model_candidates[0]
model = onnx.load(str(model_path))
# Locate the serialized config of the custom op: the first string_data entry of
# the initializer named "feature/config". Fail with a clear message instead of
# leaving `blob` unbound (the previous loop surfaced that as a NameError).
config_init = next(
    (init for init in model.graph.initializer if init.name == "feature/config"),
    None,
)
if config_init is None:
    raise ValueError("Model has no 'feature/config' initializer")
if not config_init.string_data:
    raise ValueError("'feature/config' initializer has no string_data payload")
blob = bytes(config_init.string_data[0])

# The blob is decoded as big-endian float32 values.
be_arr = np.frombuffer(blob, dtype='>f4').copy()
print(f"Config blob: {len(be_arr)} floats")

# Layout assumed by this script: a 21x50 weight matrix (row-major), then 50
# bias values, then trailing metadata. Validate the size up front so a short
# blob yields a readable error rather than a reshape failure.
FE_IN, FE_OUT = 21, 50
weights_end = FE_IN * FE_OUT
bias_end = weights_end + FE_OUT
if be_arr.size < bias_end:
    raise ValueError(
        f"Config blob too small: expected at least {bias_end} floats, got {be_arr.size}"
    )

W_fe = be_arr[:weights_end].reshape(FE_IN, FE_OUT).astype(np.float32)
b_fe = be_arr[weights_end:bias_end].astype(np.float32)
metadata = be_arr[bias_end:]
print(f"W: {W_fe.shape}, b: {b_fe.shape}, metadata: {metadata.shape}")
print(f"Metadata values: {metadata}")
# Build a copy of the model in which every OneOCRFeatureExtract node is replaced
# by a standard Gemm computing  feature = x @ W_fe + b_fe.
new_model = copy.deepcopy(model)

# Swap the opaque config initializer for explicit weight/bias initializers.
# The weight is stored transposed (50 x 21); Gemm's transB=1 transposes it back,
# so the product is x @ W_fe. Gemm requires a 2-D input, so this assumes the
# feature input is (N, 21) — TODO confirm against the graph.
new_inits = [
    init for init in new_model.graph.initializer if init.name != "feature/config"
]
W_init = numpy_helper.from_array(W_fe.T, name="fe_weight")
b_init = numpy_helper.from_array(b_fe, name="fe_bias")
new_inits.extend([W_init, b_init])
del new_model.graph.initializer[:]
new_model.graph.initializer.extend(new_inits)

new_nodes = []
replaced_count = 0
for node in new_model.graph.node:
    if node.op_type != "OneOCRFeatureExtract":
        new_nodes.append(node)
        continue
    # Wire the Gemm to the custom node's own tensors rather than hard-coded
    # names, so whatever consumed the custom op's output stays connected.
    data_inputs = [name for name in node.input if name != "feature/config"]
    if not data_inputs or not node.output:
        raise ValueError(
            "Unexpected OneOCRFeatureExtract signature: "
            f"inputs={list(node.input)}, outputs={list(node.output)}"
        )
    feature_in = data_inputs[0]
    gemm_node = helper.make_node(
        "Gemm",
        inputs=[feature_in, "fe_weight", "fe_bias"],
        outputs=[node.output[0]],
        alpha=1.0,
        beta=1.0,
        transB=1,
    )
    new_nodes.append(gemm_node)
    replaced_count += 1
    print(f"Replaced OneOCRFeatureExtract with Gemm({feature_in} @ W.T + b)")

if replaced_count == 0:
    # Without this, the config would be stripped while nothing was replaced.
    raise RuntimeError("No OneOCRFeatureExtract node found; nothing was replaced")

del new_model.graph.node[:]
new_model.graph.node.extend(new_nodes)

# The config tensor may also be declared as a graph input; drop that too.
new_inputs = [inp for inp in new_model.graph.input if inp.name != "feature/config"]
del new_model.graph.input[:]
new_model.graph.input.extend(new_inputs)

# Remove the custom op domain from the opset imports.
new_opsets = [op for op in new_model.opset_import if op.domain != "com.microsoft.oneocr"]
del new_model.opset_import[:]
new_model.opset_import.extend(new_opsets)
# Validate the rewritten graph. A validation failure is only reported, not
# fatal, so the model is still written out and exercised afterwards.
try:
    onnx.checker.check_model(new_model)
except Exception as err:
    print(f"Model validation warning: {err}")
else:
    print("Model validation passed!")

# Persist the rewritten model under temp/ so onnxruntime can load it.
output_dir = Path("temp")
output_dir.mkdir(exist_ok=True)
modified_path = "temp/model_11_modified.onnx"
onnx.save(new_model, modified_path)
print(f"Saved modified model to {modified_path}")
print("\n--- Testing inference ---")

# The untouched model is expected to fail to load — presumably because
# onnxruntime has no kernel for the custom op.
try:
    sess_orig = ort.InferenceSession(str(model_path))
    print("Original model loaded (unexpected!)")
except Exception as e:
    print(f"Original model failed (expected): {str(e)[:100]}")

# Run a few probe inputs through the rewritten model and print the first output.
try:
    sess_mod = ort.InferenceSession(modified_path)
    print("Modified model loaded successfully!")

    # Fixed seed (0) so the random probe is reproducible from run to run.
    rng = np.random.default_rng(0)
    probes = [
        ("Zero input", np.zeros((1, 21, 1, 1), dtype=np.float32)),
        (
            "Random input",
            rng.standard_normal((1, 21, 1, 1)).astype(np.float32) * 0.5,
        ),
        (
            "Typical scores",
            np.array(
                [
                    0.9, 0.1, 0.05, 0.02, 0.01, 0.3, 0.7, 0.6, 0.4, 0.5,
                    0.3, 0.01, 0.02, 0.01, 0.01, 0.01, 0.02, 0.01, 0.01, 0.01, 0.8,
                ],
                dtype=np.float32,
            ).reshape(1, 21, 1, 1),
        ),
    ]
    for label, test_input in probes:
        result = sess_mod.run(None, {"data": test_input})
        print(f"{label} → softmax: {result[0]}")

except Exception as e:
    print(f"Modified model failed: {e}")
    import traceback
    traceback.print_exc()
print("\n--- Testing with ReLU after feature extraction ---")

# Rewrite the original model again, this time feeding the Gemm result through a
# ReLU before it reaches the rest of the graph.
new_model2 = copy.deepcopy(model)

# Replace the opaque config initializer with explicit weight/bias tensors. The
# weight is stored transposed; Gemm's transB=1 undoes that, giving x @ W_fe.
new_inits2 = [
    init for init in new_model2.graph.initializer if init.name != "feature/config"
]
new_inits2.append(numpy_helper.from_array(W_fe.T, name="fe_weight"))
new_inits2.append(numpy_helper.from_array(b_fe, name="fe_bias"))
del new_model2.graph.initializer[:]
new_model2.graph.initializer.extend(new_inits2)

new_nodes2 = []
replaced_count2 = 0
for node in new_model2.graph.node:
    if node.op_type != "OneOCRFeatureExtract":
        new_nodes2.append(node)
        continue
    # Use the custom node's own tensor names so downstream consumers of its
    # output stay connected to the ReLU output.
    data_inputs = [name for name in node.input if name != "feature/config"]
    if not data_inputs or not node.output:
        raise ValueError(
            "Unexpected OneOCRFeatureExtract signature: "
            f"inputs={list(node.input)}, outputs={list(node.output)}"
        )
    feature_out = node.output[0]
    pre_relu = f"{feature_out}_pre"
    new_nodes2.append(
        helper.make_node(
            "Gemm",
            inputs=[data_inputs[0], "fe_weight", "fe_bias"],
            outputs=[pre_relu],
            alpha=1.0,
            beta=1.0,
            transB=1,
        )
    )
    new_nodes2.append(helper.make_node("Relu", inputs=[pre_relu], outputs=[feature_out]))
    replaced_count2 += 1

if replaced_count2 == 0:
    raise RuntimeError("No OneOCRFeatureExtract node found; nothing was replaced")

del new_model2.graph.node[:]
new_model2.graph.node.extend(new_nodes2)

# Drop the config graph input and the custom op domain.
new_inputs2 = [inp for inp in new_model2.graph.input if inp.name != "feature/config"]
del new_model2.graph.input[:]
new_model2.graph.input.extend(new_inputs2)

new_opsets2 = [op for op in new_model2.opset_import if op.domain != "com.microsoft.oneocr"]
del new_model2.opset_import[:]
new_model2.opset_import.extend(new_opsets2)

# Best-effort validation: report problems but still save and run the model.
try:
    onnx.checker.check_model(new_model2)
    print("Model validation passed!")
except Exception as e:
    print(f"Model validation warning: {e}")

modified_path2 = "temp/model_11_modified_relu.onnx"
# Ensure the output directory exists even if this section runs on its own.
Path(modified_path2).parent.mkdir(exist_ok=True)
onnx.save(new_model2, modified_path2)

try:
    sess_mod2 = ort.InferenceSession(modified_path2)
    # Fixed seed (0) so the random probe is reproducible from run to run.
    rng = np.random.default_rng(0)
    probes = [
        ("Zero input", np.zeros((1, 21, 1, 1), dtype=np.float32)),
        (
            "Random input",
            rng.standard_normal((1, 21, 1, 1)).astype(np.float32) * 0.5,
        ),
    ]
    for label, test_input in probes:
        result = sess_mod2.run(None, {"data": test_input})
        print(f"{label} (Gemm+ReLU) → softmax: {result[0]}")
except Exception as e:
    print(f"Failed: {e}")
    import traceback
    traceback.print_exc()