|
|
""" |
|
|
Unlock ALL OneOCRFeatureExtract models (11-33). |
|
|
|
|
|
Replaces the custom `OneOCRFeatureExtract` op (domain: com.microsoft.oneocr) |
|
|
with a standard ONNX `Gemm` node. The weights are extracted from the |
|
|
big-endian float32 config blob stored as a STRING tensor. |
|
|
|
|
|
Config blob structure (for small/medium LM models 11-32): |
|
|
- W[input_dim × output_dim] as big-endian float32 |
|
|
- b[output_dim] as big-endian float32 |
|
|
- metadata[remaining] containing dimensions, flags, etc. |
|
|
|
|
|
Usage: |
|
|
python unlock_models.py # unlock all 11-33 |
|
|
python unlock_models.py 11 22 33 # unlock specific models |
|
|
""" |
|
|
|
|
|
import onnx |
|
|
from onnx import numpy_helper, helper |
|
|
import numpy as np |
|
|
from pathlib import Path |
|
|
import copy |
|
|
import sys |
|
|
|
|
|
|
|
|
def extract_fe_weights(model: onnx.ModelProto) -> tuple[np.ndarray, np.ndarray, dict]:
    """Recover the linear-layer weights hidden in the OneOCRFeatureExtract blob.

    The "feature/config" initializer packs, as big-endian float32:
    W[in_dim × out_dim], then b[out_dim], then a metadata tail whose float
    values encode (among other things) the dimensions themselves.

    Returns:
        (weight_matrix, bias, metadata_dict)

    Raises:
        ValueError: if the config blob, the custom node, or the
            dimensions cannot be located.
    """
    # Locate the packed config blob among the graph initializers.
    config_blob = next(
        (bytes(init.string_data[0]) if init.string_data else bytes(init.raw_data)
         for init in model.graph.initializer
         if init.name == "feature/config"),
        None,
    )
    if config_blob is None:
        raise ValueError("No feature/config initializer found")

    # Decode the whole blob as big-endian float32 (copy: frombuffer is read-only).
    be_arr = np.frombuffer(config_blob, dtype='>f4').copy()

    # Locate the custom op whose output feeds the rest of the graph.
    fe_node = next(
        (node for node in model.graph.node if node.op_type == "OneOCRFeatureExtract"),
        None,
    )
    if fe_node is None:
        raise ValueError("No OneOCRFeatureExtract node found")

    in_dim = None
    out_dim = None

    # Primary strategy: scan the last 10 floats of the blob for the known
    # (21, 50|51) dimension pair stored in the metadata tail.
    for pos in range(len(be_arr) - 10, len(be_arr)):
        if be_arr[pos] == 21.0 and pos + 1 < len(be_arr) and be_arr[pos + 1] in [50.0, 51.0]:
            in_dim = int(be_arr[pos])
            out_dim = int(be_arr[pos + 1])
            break

    # Fallback for in_dim: read the second dimension of the "data" graph input.
    if in_dim is None:
        for graph_input in model.graph.input:
            if graph_input.name != "data":
                continue
            dims = [d.dim_value for d in graph_input.type.tensor_type.shape.dim]
            if len(dims) >= 2:
                in_dim = dims[1] if dims[1] > 0 else 21
            break

    # Fallback for out_dim: the downstream Gemm consuming our output has a
    # weight initializer whose shape reveals the feature width.
    if out_dim is None:
        fe_output = fe_node.output[0]
        for node in model.graph.node:
            if node.op_type == "Gemm" and fe_output in node.input:
                weight_name = node.input[1]
                for init in model.graph.initializer:
                    if init.name == weight_name:
                        mat = numpy_helper.to_array(init)
                        out_dim = mat.shape[0] if len(mat.shape) == 2 else mat.shape[1]
                        break
                break

    if in_dim is None or out_dim is None:
        raise ValueError(f"Could not determine dimensions: in={in_dim}, out={out_dim}")

    n_weights = in_dim * out_dim
    n_bias = out_dim
    if len(be_arr) < n_weights + n_bias:
        raise ValueError(f"Config blob too small: {len(be_arr)} < {n_weights + n_bias}")

    # Slice blob into weight matrix, bias vector, and leftover metadata.
    W = be_arr[:n_weights].reshape(in_dim, out_dim).astype(np.float32)
    b = be_arr[n_weights:n_weights + n_bias].astype(np.float32)
    metadata = be_arr[n_weights + n_bias:]

    meta_dict = {
        "in_dim": in_dim,
        "out_dim": out_dim,
        "total_floats": len(be_arr),
        "metadata_floats": len(metadata),
        "metadata_values": metadata.tolist(),
    }
    return W, b, meta_dict
|
|
|
|
|
|
|
|
def unlock_model(model_path: Path, output_dir: Path) -> Path:
    """Replace OneOCRFeatureExtract with a standard Gemm in an ONNX model.

    The custom op is swapped for ``Gemm(input, fe_weight, fe_bias)`` with
    ``transB=1``; the weights come from the "feature/config" blob decoded by
    :func:`extract_fe_weights`. The original file is left untouched — a
    ``*_unlocked.onnx`` copy is written to *output_dir* and smoke-tested
    with one zero-filled inference.

    Args:
        model_path: Path to the original ONNX model.
        output_dir: Directory to save the modified model.

    Returns:
        Path to the modified model, or *model_path* unchanged when the
        model has no custom op or weight extraction fails.
    """
    model = onnx.load(str(model_path))

    # Nothing to do if this model never used the custom op.
    has_custom_op = any(
        node.op_type == "OneOCRFeatureExtract"
        for node in model.graph.node
    )
    if not has_custom_op:
        print(f"  {model_path.name}: No OneOCRFeatureExtract — skipping")
        return model_path

    # Best-effort: report extraction failures and fall back to the original.
    try:
        W, b, meta = extract_fe_weights(model)
    except Exception as e:
        print(f"  {model_path.name}: Failed to extract weights: {e}")
        return model_path

    print(f"  {model_path.name}: W[{meta['in_dim']}×{meta['out_dim']}] + b[{meta['out_dim']}] "
          f"(metadata: {meta['metadata_floats']} floats)")

    # Work on a copy so the in-memory original stays intact.
    new_model = copy.deepcopy(model)

    fe_node = None
    for node in new_model.graph.node:
        if node.op_type == "OneOCRFeatureExtract":
            fe_node = node
            break

    fe_input = fe_node.input[0]
    fe_output = fe_node.output[0]

    # Drop the config blob; add the decoded weight/bias initializers.
    # W is stored transposed because the Gemm below uses transB=1.
    new_inits = [init for init in new_model.graph.initializer if init.name != "feature/config"]
    new_inits.append(numpy_helper.from_array(W.T, name="fe_weight"))
    new_inits.append(numpy_helper.from_array(b, name="fe_bias"))
    del new_model.graph.initializer[:]
    new_model.graph.initializer.extend(new_inits)

    # Rebuild the node list with the custom op replaced in place
    # (preserving topological order).
    new_nodes = []
    for node in new_model.graph.node:
        if node.op_type == "OneOCRFeatureExtract":
            gemm_node = helper.make_node(
                "Gemm",
                inputs=[fe_input, "fe_weight", "fe_bias"],
                outputs=[fe_output],
                alpha=1.0,
                beta=1.0,
                transB=1,  # fe_weight holds W.T
            )
            new_nodes.append(gemm_node)
        else:
            new_nodes.append(node)
    del new_model.graph.node[:]
    new_model.graph.node.extend(new_nodes)

    # "feature/config" may also be declared as a graph input — remove it.
    new_inputs = [inp for inp in new_model.graph.input if inp.name != "feature/config"]
    del new_model.graph.input[:]
    new_model.graph.input.extend(new_inputs)

    # Drop the custom-op opset so a plain ONNX runtime will load the model.
    new_opsets = [op for op in new_model.opset_import if op.domain != "com.microsoft.oneocr"]
    del new_model.opset_import[:]
    new_model.opset_import.extend(new_opsets)

    output_dir.mkdir(parents=True, exist_ok=True)
    out_name = model_path.stem + "_unlocked.onnx"
    out_path = output_dir / out_name
    onnx.save(new_model, str(out_path))

    # Smoke-test: one inference with zero-filled dummy inputs. Failures are
    # reported but non-fatal — the patched file is kept either way.
    try:
        import onnxruntime as ort
        sess = ort.InferenceSession(str(out_path))

        # BUGFIX: feeds were previously always float32, which made this
        # check fail spuriously for integer-typed graph inputs. Map the
        # ORT type string (e.g. "tensor(int64)") to the matching dtype.
        ort_to_np = {
            "tensor(float)": np.float32,
            "tensor(float16)": np.float16,
            "tensor(double)": np.float64,
            "tensor(int64)": np.int64,
            "tensor(int32)": np.int32,
            "tensor(uint8)": np.uint8,
            "tensor(bool)": np.bool_,
        }

        feeds = {}
        for inp in sess.get_inputs():
            # Dynamic dims show up as None/str — substitute 1.
            shape = [d if isinstance(d, int) and d > 0 else 1 for d in inp.shape]
            feeds[inp.name] = np.zeros(shape, dtype=ort_to_np.get(inp.type, np.float32))

        result = sess.run(None, feeds)
        print(f"  ✓ Inference OK — output shapes: {[r.shape for r in result]}")

    except Exception as e:
        print(f"  ✗ Inference failed: {e}")

    return out_path
|
|
|
|
|
|
|
|
def main():
    """Unlock the requested models (CLI indices, or all of 11-33)."""
    models_dir = Path("oneocr_extracted/onnx_models")
    output_dir = Path("oneocr_extracted/onnx_models_unlocked")

    # Model indices: command-line arguments if given, otherwise 11..33.
    indices = [int(arg) for arg in sys.argv[1:]] if len(sys.argv) > 1 else list(range(11, 34))

    print(f"Unlocking {len(indices)} models...")
    print(f"Source: {models_dir}")
    print(f"Output: {output_dir}")
    print()

    results = {"success": [], "skip": [], "fail": []}

    for idx in indices:
        matches = list(models_dir.glob(f"model_{idx:02d}_*"))
        if not matches:
            print(f"  model_{idx:02d}: NOT FOUND")
            results["fail"].append(idx)
            continue

        model_path = matches[0]
        try:
            out = unlock_model(model_path, output_dir)
            # unlock_model returns the input path unchanged when it skips.
            bucket = "skip" if out == model_path else "success"
            results[bucket].append(idx)
        except Exception as e:
            print(f"  model_{idx:02d}: ERROR — {e}")
            results["fail"].append(idx)

    # Summary banner.
    print(f"\n{'='*60}")
    print(f"Results:")
    print(f"  Unlocked: {len(results['success'])} — {results['success']}")
    print(f"  Skipped:  {len(results['skip'])} — {results['skip']}")
    print(f"  Failed:   {len(results['fail'])} — {results['fail']}")
|
|
|
|
|
|
|
|
# Run the unlocker only when executed as a script (not on import).
if __name__ == "__main__":
    main()
|
|
|