"""Replace OneOCRFeatureExtract with standard Gemm and test the model. This script modifies a model ONNX graph to replace the custom op, then runs inference.""" import onnx from onnx import numpy_helper, helper, TensorProto import numpy as np from pathlib import Path import onnxruntime as ort import copy models_dir = Path("oneocr_extracted/onnx_models") # Load model_11 model_path = list(models_dir.glob("model_11_*"))[0] model = onnx.load(str(model_path)) # Extract the config blob (big-endian float32) for init in model.graph.initializer: if init.name == "feature/config": blob = bytes(init.string_data[0]) break be_arr = np.frombuffer(blob, dtype='>f4').copy() print(f"Config blob: {len(be_arr)} floats") # Extract weight matrix and bias: first 1050 = W[21×50], next 50 = bias, rest = metadata W_fe = be_arr[:1050].reshape(21, 50).astype(np.float32) b_fe = be_arr[1050:1100].astype(np.float32) metadata = be_arr[1100:] print(f"W: {W_fe.shape}, b: {b_fe.shape}, metadata: {metadata.shape}") print(f"Metadata values: {metadata}") # Now build a modified model: # Replace OneOCRFeatureExtract node with Gemm node # OneOCRFeatureExtract takes ['29', 'feature/config'] → ['oneocr_feature'] # Replace with: Gemm(['29', 'fe_weight', 'fe_bias']) → ['oneocr_feature'] new_model = copy.deepcopy(model) # Remove the feature/config initializer and add W, b initializers new_inits = [] for init in new_model.graph.initializer: if init.name == "feature/config": continue new_inits.append(init) # Add the extracted weights as initializers W_init = numpy_helper.from_array(W_fe.T, name="fe_weight") # transB=1: [50, 21] b_init = numpy_helper.from_array(b_fe, name="fe_bias") new_inits.append(W_init) new_inits.append(b_init) del new_model.graph.initializer[:] new_model.graph.initializer.extend(new_inits) # Replace the OneOCRFeatureExtract node with Gemm new_nodes = [] for node in new_model.graph.node: if node.op_type == "OneOCRFeatureExtract": # Input '29' has shape [1, 21], output 'oneocr_feature' should be [1, 50] gemm_node = helper.make_node( "Gemm", inputs=["29", "fe_weight", "fe_bias"], outputs=["oneocr_feature"], alpha=1.0, beta=1.0, transB=1, ) new_nodes.append(gemm_node) print(f"Replaced OneOCRFeatureExtract with Gemm(29 @ W.T + b)") else: new_nodes.append(node) del new_model.graph.node[:] new_model.graph.node.extend(new_nodes) # Also need to handle the input value_infos for the new weights # Remove feature/config from graph inputs if present new_inputs = [] for inp in new_model.graph.input: if inp.name != "feature/config": new_inputs.append(inp) del new_model.graph.input[:] new_model.graph.input.extend(new_inputs) # Fix opset — remove com.microsoft.oneocr domain new_opsets = [] for op in new_model.opset_import: if op.domain != "com.microsoft.oneocr": new_opsets.append(op) del new_model.opset_import[:] new_model.opset_import.extend(new_opsets) # Validate try: onnx.checker.check_model(new_model) print("Model validation passed!") except Exception as e: print(f"Model validation warning: {e}") # Save modified model modified_path = "temp/model_11_modified.onnx" Path("temp").mkdir(exist_ok=True) onnx.save(new_model, modified_path) print(f"Saved modified model to {modified_path}") # Try to run inference print(f"\n--- Testing inference ---") # Test with original model first (will fail due to custom op) try: sess_orig = ort.InferenceSession(str(model_path)) print("Original model loaded (unexpected!)") except Exception as e: print(f"Original model failed (expected): {str(e)[:100]}") # Test with modified model try: sess_mod = ort.InferenceSession(modified_path) print("Modified model loaded successfully!") # Run with test input test_input = np.zeros((1, 21, 1, 1), dtype=np.float32) result = sess_mod.run(None, {"data": test_input}) print(f"Zero input → softmax: {result[0]}") # Random input test_input = np.random.randn(1, 21, 1, 1).astype(np.float32) * 0.5 result = sess_mod.run(None, {"data": test_input}) print(f"Random input → softmax: {result[0]}") # Typical CTC features (normalized scores) test_input = np.array([ 0.9, 0.1, 0.05, 0.02, 0.01, 0.3, 0.7, 0.6, 0.4, 0.5, 0.3, 0.01, 0.02, 0.01, 0.01, 0.01, 0.02, 0.01, 0.01, 0.01, 0.8 ], dtype=np.float32).reshape(1, 21, 1, 1) result = sess_mod.run(None, {"data": test_input}) print(f"Typical scores → softmax: {result[0]}") except Exception as e: print(f"Modified model failed: {e}") import traceback traceback.print_exc() # Also try with ReLU after Gemm (maybe the custom op includes activation) print(f"\n--- Testing with ReLU after feature extraction ---") new_model2 = copy.deepcopy(model) new_inits2 = [] for init in new_model2.graph.initializer: if init.name == "feature/config": continue new_inits2.append(init) new_inits2.append(numpy_helper.from_array(W_fe.T, name="fe_weight")) new_inits2.append(numpy_helper.from_array(b_fe, name="fe_bias")) del new_model2.graph.initializer[:] new_model2.graph.initializer.extend(new_inits2) new_nodes2 = [] for node in new_model2.graph.node: if node.op_type == "OneOCRFeatureExtract": gemm_node = helper.make_node("Gemm", inputs=["29", "fe_weight", "fe_bias"], outputs=["oneocr_feature_pre"], alpha=1.0, beta=1.0, transB=1) relu_node = helper.make_node("Relu", inputs=["oneocr_feature_pre"], outputs=["oneocr_feature"]) new_nodes2.append(gemm_node) new_nodes2.append(relu_node) else: new_nodes2.append(node) del new_model2.graph.node[:] new_model2.graph.node.extend(new_nodes2) new_inputs2 = [inp for inp in new_model2.graph.input if inp.name != "feature/config"] del new_model2.graph.input[:] new_model2.graph.input.extend(new_inputs2) new_opsets2 = [op for op in new_model2.opset_import if op.domain != "com.microsoft.oneocr"] del new_model2.opset_import[:] new_model2.opset_import.extend(new_opsets2) modified_path2 = "temp/model_11_modified_relu.onnx" onnx.save(new_model2, modified_path2) try: sess_mod2 = ort.InferenceSession(modified_path2) test_input = np.zeros((1, 21, 1, 1), dtype=np.float32) result = sess_mod2.run(None, {"data": test_input}) print(f"Zero input (Gemm+ReLU) → softmax: {result[0]}") test_input = np.random.randn(1, 21, 1, 1).astype(np.float32) * 0.5 result = sess_mod2.run(None, {"data": test_input}) print(f"Random input (Gemm+ReLU) → softmax: {result[0]}") except Exception as e: print(f"Failed: {e}")