File size: 6,755 Bytes
ce847d4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 |
"""Replace OneOCRFeatureExtract with standard Gemm and test the model.
This script modifies a model ONNX graph to replace the custom op, then runs inference."""
import onnx
from onnx import numpy_helper, helper, TensorProto
import numpy as np
from pathlib import Path
import onnxruntime as ort
import copy
models_dir = Path("oneocr_extracted/onnx_models")

# Load model_11 (sorted so the pick is deterministic if several files match).
model_candidates = sorted(models_dir.glob("model_11_*"))
if not model_candidates:
    raise FileNotFoundError(f"No model_11_* file found in {models_dir}")
model_path = model_candidates[0]
model = onnx.load(str(model_path))

# Extract the config blob: the custom op's parameters live in a STRING
# initializer whose payload is a run of big-endian float32 values.
blob = None
for init in model.graph.initializer:
    if init.name == "feature/config":
        blob = bytes(init.string_data[0])
        break
if blob is None:
    # Without this guard a missing initializer surfaces later as a confusing
    # NameError on `blob` instead of a clear diagnostic.
    raise RuntimeError("Initializer 'feature/config' not found in model graph")

be_arr = np.frombuffer(blob, dtype='>f4').copy()
print(f"Config blob: {len(be_arr)} floats")

# Blob layout: first 1050 floats = W[21x50], next 50 = bias, rest = metadata.
if len(be_arr) < 1100:
    raise ValueError(
        f"Config blob too small: {len(be_arr)} floats, expected at least 1100")
W_fe = be_arr[:1050].reshape(21, 50).astype(np.float32)
b_fe = be_arr[1050:1100].astype(np.float32)
metadata = be_arr[1100:]
print(f"W: {W_fe.shape}, b: {b_fe.shape}, metadata: {metadata.shape}")
print(f"Metadata values: {metadata}")
# Build a modified model: replace the custom OneOCRFeatureExtract node with a
# standard Gemm so the graph runs on a stock onnxruntime.
#   OneOCRFeatureExtract(['29', 'feature/config']) -> ['oneocr_feature']
# becomes
#   Gemm(['29', 'fe_weight', 'fe_bias'], transB=1) -> ['oneocr_feature']
new_model = copy.deepcopy(model)

# Drop the packed config initializer; install the unpacked weight and bias.
new_inits = [init for init in new_model.graph.initializer
             if init.name != "feature/config"]
# Gemm with transB=1 computes X @ B.T, so store W transposed: [50, 21].
new_inits.append(numpy_helper.from_array(W_fe.T, name="fe_weight"))
new_inits.append(numpy_helper.from_array(b_fe, name="fe_bias"))
del new_model.graph.initializer[:]
new_model.graph.initializer.extend(new_inits)

# Swap the custom node for Gemm. Count replacements so a model without the
# expected node fails loudly instead of being silently re-saved unchanged.
replaced = 0
new_nodes = []
for node in new_model.graph.node:
    if node.op_type == "OneOCRFeatureExtract":
        # Input '29' has shape [1, 21]; output 'oneocr_feature' is [1, 50].
        new_nodes.append(helper.make_node(
            "Gemm",
            inputs=["29", "fe_weight", "fe_bias"],
            outputs=["oneocr_feature"],
            alpha=1.0,
            beta=1.0,
            transB=1,
        ))
        replaced += 1
        print("Replaced OneOCRFeatureExtract with Gemm(29 @ W.T + b)")
    else:
        new_nodes.append(node)
if replaced == 0:
    raise RuntimeError("No OneOCRFeatureExtract node found to replace")
del new_model.graph.node[:]
new_model.graph.node.extend(new_nodes)

# Remove feature/config from the graph inputs, if it was listed there.
new_inputs = [inp for inp in new_model.graph.input
              if inp.name != "feature/config"]
del new_model.graph.input[:]
new_model.graph.input.extend(new_inputs)

# Fix the opset imports: drop the com.microsoft.oneocr custom domain so the
# model only declares standard operator sets.
new_opsets = [op for op in new_model.opset_import
              if op.domain != "com.microsoft.oneocr"]
del new_model.opset_import[:]
new_model.opset_import.extend(new_opsets)

# Validate. Warn rather than abort: the checker can reject graphs that
# onnxruntime still loads, and the real test is the inference run below.
try:
    onnx.checker.check_model(new_model)
    print("Model validation passed!")
except Exception as e:
    print(f"Model validation warning: {e}")

# Save the modified model.
modified_path = "temp/model_11_modified.onnx"
Path("temp").mkdir(exist_ok=True)
onnx.save(new_model, modified_path)
print(f"Saved modified model to {modified_path}")
# --- Inference smoke tests ---
print("\n--- Testing inference ---")

# The original model should fail to load: stock onnxruntime has no kernel
# registered for the custom OneOCRFeatureExtract op.
try:
    sess_orig = ort.InferenceSession(str(model_path))
    print("Original model loaded (unexpected!)")
except Exception as e:
    print(f"Original model failed (expected): {str(e)[:100]}")

# The modified (Gemm-only) model should load and run end to end.
try:
    sess_mod = ort.InferenceSession(modified_path)
    print("Modified model loaded successfully!")

    # All-zero probe.
    test_input = np.zeros((1, 21, 1, 1), dtype=np.float32)
    result = sess_mod.run(None, {"data": test_input})
    print(f"Zero input → softmax: {result[0]}")

    # Small random probe.
    test_input = np.random.randn(1, 21, 1, 1).astype(np.float32) * 0.5
    result = sess_mod.run(None, {"data": test_input})
    print(f"Random input → softmax: {result[0]}")

    # A hand-picked vector resembling normalized CTC scores.
    typical_scores = [
        0.9, 0.1, 0.05, 0.02, 0.01, 0.3, 0.7, 0.6, 0.4, 0.5,
        0.3, 0.01, 0.02, 0.01, 0.01, 0.01, 0.02, 0.01, 0.01, 0.01, 0.8,
    ]
    test_input = np.array(typical_scores, dtype=np.float32).reshape(1, 21, 1, 1)
    result = sess_mod.run(None, {"data": test_input})
    print(f"Typical scores → softmax: {result[0]}")
except Exception as e:
    print(f"Modified model failed: {e}")
    import traceback
    traceback.print_exc()
# --- Variant: the custom op may include a ReLU activation after the affine
# transform, so build a second model with Gemm followed by Relu and compare. ---
print("\n--- Testing with ReLU after feature extraction ---")

new_model2 = copy.deepcopy(model)

# Same surgery as above: drop the packed config, install unpacked W/b.
new_inits2 = [init for init in new_model2.graph.initializer
              if init.name != "feature/config"]
new_inits2.append(numpy_helper.from_array(W_fe.T, name="fe_weight"))
new_inits2.append(numpy_helper.from_array(b_fe, name="fe_bias"))
del new_model2.graph.initializer[:]
new_model2.graph.initializer.extend(new_inits2)

new_nodes2 = []
for node in new_model2.graph.node:
    if node.op_type == "OneOCRFeatureExtract":
        # Gemm writes an intermediate tensor; Relu produces the original
        # output name so downstream consumers are untouched.
        new_nodes2.append(helper.make_node(
            "Gemm", inputs=["29", "fe_weight", "fe_bias"],
            outputs=["oneocr_feature_pre"], alpha=1.0, beta=1.0, transB=1))
        new_nodes2.append(helper.make_node(
            "Relu", inputs=["oneocr_feature_pre"], outputs=["oneocr_feature"]))
    else:
        new_nodes2.append(node)
del new_model2.graph.node[:]
new_model2.graph.node.extend(new_nodes2)

new_inputs2 = [inp for inp in new_model2.graph.input
               if inp.name != "feature/config"]
del new_model2.graph.input[:]
new_model2.graph.input.extend(new_inputs2)

new_opsets2 = [op for op in new_model2.opset_import
               if op.domain != "com.microsoft.oneocr"]
del new_model2.opset_import[:]
new_model2.opset_import.extend(new_opsets2)

modified_path2 = "temp/model_11_modified_relu.onnx"
# Don't rely on the earlier section having created temp/ — make this
# save self-sufficient.
Path("temp").mkdir(exist_ok=True)
onnx.save(new_model2, modified_path2)

try:
    sess_mod2 = ort.InferenceSession(modified_path2)
    test_input = np.zeros((1, 21, 1, 1), dtype=np.float32)
    result = sess_mod2.run(None, {"data": test_input})
    print(f"Zero input (Gemm+ReLU) → softmax: {result[0]}")
    test_input = np.random.randn(1, 21, 1, 1).astype(np.float32) * 0.5
    result = sess_mod2.run(None, {"data": test_input})
    print(f"Random input (Gemm+ReLU) → softmax: {result[0]}")
except Exception as e:
    print(f"Failed: {e}")
|