File size: 6,755 Bytes
ce847d4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
"""Replace OneOCRFeatureExtract with standard Gemm and test the model.
This script modifies a model ONNX graph to replace the custom op, then runs inference."""
import onnx
from onnx import numpy_helper, helper, TensorProto
import numpy as np
from pathlib import Path
import onnxruntime as ort
import copy

models_dir = Path("oneocr_extracted/onnx_models")

# Load model_11 — fail fast with a clear message if no matching file exists
# (the original bare list(...)[0] raised an opaque IndexError).
candidates = sorted(models_dir.glob("model_11_*"))
if not candidates:
    raise FileNotFoundError(f"No model_11_* file found in {models_dir}")
model_path = candidates[0]
model = onnx.load(str(model_path))

# Extract the config blob (big-endian float32) stored as a string initializer.
# Guard against the initializer being missing — previously `blob` stayed
# undefined and a confusing NameError surfaced two lines later.
blob = None
for init in model.graph.initializer:
    if init.name == "feature/config":
        blob = bytes(init.string_data[0])
        break
if blob is None:
    raise ValueError("Initializer 'feature/config' not found in the model graph")

be_arr = np.frombuffer(blob, dtype='>f4').copy()
print(f"Config blob: {len(be_arr)} floats")

# Layout assumed for the blob: first 1050 floats = W[21x50], next 50 = bias,
# remainder = metadata. Validate the size before slicing so a short blob
# produces a diagnostic instead of a reshape error.
if len(be_arr) < 1100:
    raise ValueError(f"Config blob too small: {len(be_arr)} floats, expected >= 1100")
W_fe = be_arr[:1050].reshape(21, 50).astype(np.float32)
b_fe = be_arr[1050:1100].astype(np.float32)
metadata = be_arr[1100:]
print(f"W: {W_fe.shape}, b: {b_fe.shape}, metadata: {metadata.shape}")
print(f"Metadata values: {metadata}")

# Build a modified model in which the custom OneOCRFeatureExtract node is
# swapped for a standard Gemm driven by the weights recovered from the blob.
# Original node: OneOCRFeatureExtract(['29', 'feature/config']) -> ['oneocr_feature']
# Replacement:   Gemm(['29', 'fe_weight', 'fe_bias'])           -> ['oneocr_feature']

new_model = copy.deepcopy(model)

# Drop the packed config initializer and append the unpacked weight/bias
# tensors in its place.
kept_inits = [t for t in new_model.graph.initializer if t.name != "feature/config"]
kept_inits.append(numpy_helper.from_array(W_fe.T, name="fe_weight"))  # transB=1: [50, 21]
kept_inits.append(numpy_helper.from_array(b_fe, name="fe_bias"))

del new_model.graph.initializer[:]
new_model.graph.initializer.extend(kept_inits)

# Rewrite the node list, substituting the custom op with Gemm.
rewritten_nodes = []
for node in new_model.graph.node:
    if node.op_type != "OneOCRFeatureExtract":
        rewritten_nodes.append(node)
        continue
    # Input '29' has shape [1, 21]; output 'oneocr_feature' should be [1, 50]
    rewritten_nodes.append(
        helper.make_node(
            "Gemm",
            inputs=["29", "fe_weight", "fe_bias"],
            outputs=["oneocr_feature"],
            alpha=1.0,
            beta=1.0,
            transB=1,
        )
    )
    print(f"Replaced OneOCRFeatureExtract with Gemm(29 @ W.T + b)")

del new_model.graph.node[:]
new_model.graph.node.extend(rewritten_nodes)

# Clean up graph metadata so the modified model is self-consistent:
# drop 'feature/config' from the graph inputs (if it was listed there) and
# remove the custom-op opset domain, then validate and save.

surviving_inputs = [vi for vi in new_model.graph.input if vi.name != "feature/config"]
del new_model.graph.input[:]
new_model.graph.input.extend(surviving_inputs)

surviving_opsets = [imp for imp in new_model.opset_import if imp.domain != "com.microsoft.oneocr"]
del new_model.opset_import[:]
new_model.opset_import.extend(surviving_opsets)

# Validation failures are reported but non-fatal — we still attempt inference.
try:
    onnx.checker.check_model(new_model)
    print("Model validation passed!")
except Exception as e:
    print(f"Model validation warning: {e}")

# Persist the rewritten model under temp/ (created on demand).
modified_path = "temp/model_11_modified.onnx"
Path("temp").mkdir(exist_ok=True)
onnx.save(new_model, modified_path)
print(f"Saved modified model to {modified_path}")

# Smoke-test inference on both models.
print(f"\n--- Testing inference ---")

# The original model is expected to fail: ORT has no kernel for the custom op.
try:
    sess_orig = ort.InferenceSession(str(model_path))
    print("Original model loaded (unexpected!)")
except Exception as e:
    print(f"Original model failed (expected): {str(e)[:100]}")

# The modified model should load and run with plain ONNX ops.
try:
    sess_mod = ort.InferenceSession(modified_path)
    print("Modified model loaded successfully!")

    # Typical CTC features (normalized scores) for the third case.
    typical = np.array([
        0.9, 0.1, 0.05, 0.02, 0.01, 0.3, 0.7, 0.6, 0.4, 0.5,
        0.3, 0.01, 0.02, 0.01, 0.01, 0.01, 0.02, 0.01, 0.01, 0.01, 0.8
    ], dtype=np.float32).reshape(1, 21, 1, 1)

    cases = [
        ("Zero input", np.zeros((1, 21, 1, 1), dtype=np.float32)),
        ("Random input", np.random.randn(1, 21, 1, 1).astype(np.float32) * 0.5),
        ("Typical scores", typical),
    ]
    for label, arr in cases:
        result = sess_mod.run(None, {"data": arr})
        print(f"{label} → softmax: {result[0]}")

except Exception as e:
    print(f"Modified model failed: {e}")
    import traceback
    traceback.print_exc()

# Also try with ReLU after Gemm (maybe the custom op includes activation)
print(f"\n--- Testing with ReLU after feature extraction ---")
new_model2 = copy.deepcopy(model)

# Same initializer swap as before: packed config out, unpacked W/b in.
inits2 = [t for t in new_model2.graph.initializer if t.name != "feature/config"]
inits2.append(numpy_helper.from_array(W_fe.T, name="fe_weight"))
inits2.append(numpy_helper.from_array(b_fe, name="fe_bias"))
del new_model2.graph.initializer[:]
new_model2.graph.initializer.extend(inits2)

# Replace the custom op with Gemm followed by Relu, routed through an
# intermediate tensor so the final output name is unchanged.
nodes2 = []
for nd in new_model2.graph.node:
    if nd.op_type == "OneOCRFeatureExtract":
        nodes2.append(helper.make_node(
            "Gemm", inputs=["29", "fe_weight", "fe_bias"],
            outputs=["oneocr_feature_pre"], alpha=1.0, beta=1.0, transB=1))
        nodes2.append(helper.make_node(
            "Relu", inputs=["oneocr_feature_pre"], outputs=["oneocr_feature"]))
    else:
        nodes2.append(nd)
del new_model2.graph.node[:]
new_model2.graph.node.extend(nodes2)

# Prune the stale graph input and the custom-op opset domain.
pruned_inputs = [inp for inp in new_model2.graph.input if inp.name != "feature/config"]
del new_model2.graph.input[:]
new_model2.graph.input.extend(pruned_inputs)

pruned_opsets = [op for op in new_model2.opset_import if op.domain != "com.microsoft.oneocr"]
del new_model2.opset_import[:]
new_model2.opset_import.extend(pruned_opsets)

modified_path2 = "temp/model_11_modified_relu.onnx"
onnx.save(new_model2, modified_path2)

# Run the Gemm+ReLU variant on the same probe inputs; any failure is
# reported but does not abort the script.
try:
    sess_relu = ort.InferenceSession(modified_path2)
    probes = [
        ("Zero input (Gemm+ReLU)", np.zeros((1, 21, 1, 1), dtype=np.float32)),
        ("Random input (Gemm+ReLU)", np.random.randn(1, 21, 1, 1).astype(np.float32) * 0.5),
    ]
    for label, arr in probes:
        result = sess_relu.run(None, {"data": arr})
        print(f"{label} → softmax: {result[0]}")
except Exception as e:
    print(f"Failed: {e}")