# oneocr / _archive / unlock_models.py
# OneOCR Dev — reverse engineering complete, ONNX pipeline 53% match rate
# Commit: ce847d4
"""
Unlock ALL OneOCRFeatureExtract models (11-33).
Replaces the custom `OneOCRFeatureExtract` op (domain: com.microsoft.oneocr)
with a standard ONNX `Gemm` node. The weights are extracted from the
big-endian float32 config blob stored as a STRING tensor.
Config blob structure (for small/medium LM models 11-32):
- W[input_dim × output_dim] as big-endian float32
- b[output_dim] as big-endian float32
- metadata[remaining] containing dimensions, flags, etc.
Usage:
python unlock_models.py # unlock all 11-33
python unlock_models.py 11 22 33 # unlock specific models
"""
import onnx
from onnx import numpy_helper, helper
import numpy as np
from pathlib import Path
import copy
import sys
def extract_fe_weights(model: onnx.ModelProto) -> tuple[np.ndarray, np.ndarray, dict]:
    """Extract weights from the OneOCRFeatureExtract config blob.

    The "feature/config" initializer stores bytes that decode as big-endian
    float32 values laid out as:

        W[in_dim × out_dim] + b[out_dim] + metadata

    The metadata tail appears to encode the dimensions as float values
    (pattern observed in models 11-32); the graph is used as a fallback
    source for the dimensions when that pattern is not found.

    Args:
        model: Loaded ONNX model containing an OneOCRFeatureExtract node.

    Returns:
        ``(weight_matrix, bias, metadata_dict)`` where ``weight_matrix``
        has shape ``(in_dim, out_dim)``, ``bias`` has shape ``(out_dim,)``,
        and ``metadata_dict`` records the inferred dimensions plus the raw
        metadata floats.

    Raises:
        ValueError: If the config blob or custom node is missing, the
            dimensions cannot be determined, or the blob is too small to
            hold W and b.
    """
    # Locate the feature/config initializer; it may be stored either as a
    # STRING tensor (string_data) or as raw bytes (raw_data).
    config_blob = None
    for init in model.graph.initializer:
        if init.name == "feature/config":
            config_blob = bytes(init.string_data[0]) if init.string_data else bytes(init.raw_data)
            break
    if config_blob is None:
        raise ValueError("No feature/config initializer found")

    # Decode the whole blob as big-endian float32. Copy, because
    # np.frombuffer returns a read-only view over the bytes object.
    be_arr = np.frombuffer(config_blob, dtype='>f4').copy()

    # Find the OneOCRFeatureExtract node (needed below to trace its output).
    fe_node = None
    for node in model.graph.node:
        if node.op_type == "OneOCRFeatureExtract":
            fe_node = node
            break
    if fe_node is None:
        raise ValueError("No OneOCRFeatureExtract node found")

    in_dim = None
    out_dim = None

    # Heuristic: the dims appear near the end of the blob as the float pair
    # (21.0, 50.0|51.0). Clamp the window start to 0 so that blobs shorter
    # than 10 floats don't produce a negative range start, which would wrap
    # around via negative indexing and scan the wrong elements.
    for i in range(max(0, len(be_arr) - 10), len(be_arr)):
        val = be_arr[i]
        if val == 21.0 and i + 1 < len(be_arr) and be_arr[i + 1] in [50.0, 51.0]:
            in_dim = int(val)
            out_dim = int(be_arr[i + 1])
            break

    # Fallback: infer in_dim from the "data" graph input shape.
    if in_dim is None:
        for graph_input in model.graph.input:
            if graph_input.name == "data":
                shape = [d.dim_value for d in graph_input.type.tensor_type.shape.dim]
                if len(shape) >= 2:
                    # dim_value == 0 means a dynamic dimension; assume the
                    # known feature width 21 in that case.
                    in_dim = shape[1] if shape[1] > 0 else 21
                break

    # Fallback: infer out_dim from the Gemm that consumes the custom op's
    # output — its weight initializer's leading dim is the output width.
    if out_dim is None:
        fe_output = fe_node.output[0]
        for node in model.graph.node:
            if node.op_type == "Gemm" and fe_output in node.input:
                weight_name = node.input[1]
                for init in model.graph.initializer:
                    if init.name == weight_name:
                        W = numpy_helper.to_array(init)
                        out_dim = W.shape[0] if len(W.shape) == 2 else W.shape[1]
                        break
                break

    if in_dim is None or out_dim is None:
        raise ValueError(f"Could not determine dimensions: in={in_dim}, out={out_dim}")

    # Slice the blob: first in_dim*out_dim floats = W, next out_dim = b,
    # everything after that is opaque metadata.
    n_weights = in_dim * out_dim
    n_bias = out_dim
    if len(be_arr) < n_weights + n_bias:
        raise ValueError(f"Config blob too small: {len(be_arr)} < {n_weights + n_bias}")
    W = be_arr[:n_weights].reshape(in_dim, out_dim).astype(np.float32)
    b = be_arr[n_weights:n_weights + n_bias].astype(np.float32)
    metadata = be_arr[n_weights + n_bias:]
    meta_dict = {
        "in_dim": in_dim,
        "out_dim": out_dim,
        "total_floats": len(be_arr),
        "metadata_floats": len(metadata),
        "metadata_values": metadata.tolist(),
    }
    return W, b, meta_dict
def unlock_model(model_path: Path, output_dir: Path) -> Path:
    """Replace OneOCRFeatureExtract with a standard Gemm in an ONNX model.

    Extracts W/b from the "feature/config" blob, swaps the custom op for
    ``Gemm(alpha=1, beta=1, transB=1)``, strips the custom opset domain and
    the config initializer/input, and saves the result as
    ``<stem>_unlocked.onnx``. A best-effort smoke inference is then run on
    the saved model (failure is reported but does not undo the save).

    Args:
        model_path: Path to the original ONNX model.
        output_dir: Directory to save the modified model into (created if
            missing).

    Returns:
        Path to the modified model, or ``model_path`` unchanged when the
        model has no custom op or weight extraction fails.
    """
    model = onnx.load(str(model_path))

    # Skip models that don't use the custom op at all.
    has_custom_op = any(
        node.op_type == "OneOCRFeatureExtract"
        for node in model.graph.node
    )
    if not has_custom_op:
        print(f" {model_path.name}: No OneOCRFeatureExtract — skipping")
        return model_path

    # Extract W/b from the config blob; on failure, leave the model as-is.
    try:
        W, b, meta = extract_fe_weights(model)
    except Exception as e:
        print(f" {model_path.name}: Failed to extract weights: {e}")
        return model_path
    print(f" {model_path.name}: W[{meta['in_dim']}×{meta['out_dim']}] + b[{meta['out_dim']}] "
          f"(metadata: {meta['metadata_floats']} floats)")

    # Work on a deep copy so the loaded model is never half-modified.
    new_model = copy.deepcopy(model)
    fe_node = None
    for node in new_model.graph.node:
        if node.op_type == "OneOCRFeatureExtract":
            fe_node = node
            break
    fe_input = fe_node.input[0]
    fe_output = fe_node.output[0]

    # Replace initializers: drop feature/config, add the Gemm weights.
    # W is stored transposed ([out, in]) to match transB=1; make it
    # contiguous before serialization.
    new_inits = [init for init in new_model.graph.initializer if init.name != "feature/config"]
    new_inits.append(numpy_helper.from_array(np.ascontiguousarray(W.T), name="fe_weight"))
    new_inits.append(numpy_helper.from_array(b, name="fe_bias"))
    del new_model.graph.initializer[:]
    new_model.graph.initializer.extend(new_inits)

    # Swap the custom op node for a Gemm, preserving node order.
    new_nodes = []
    for node in new_model.graph.node:
        if node.op_type == "OneOCRFeatureExtract":
            gemm_node = helper.make_node(
                "Gemm",
                inputs=[fe_input, "fe_weight", "fe_bias"],
                outputs=[fe_output],
                alpha=1.0,
                beta=1.0,
                transB=1,
            )
            new_nodes.append(gemm_node)
        else:
            new_nodes.append(node)
    del new_model.graph.node[:]
    new_model.graph.node.extend(new_nodes)

    # Clean up graph inputs (remove feature/config if it was exposed there).
    new_inputs = [inp for inp in new_model.graph.input if inp.name != "feature/config"]
    del new_model.graph.input[:]
    new_model.graph.input.extend(new_inputs)

    # Remove the custom opset domain so standard runtimes accept the model.
    new_opsets = [op for op in new_model.opset_import if op.domain != "com.microsoft.oneocr"]
    del new_model.opset_import[:]
    new_model.opset_import.extend(new_opsets)

    # Save the unlocked model.
    output_dir.mkdir(parents=True, exist_ok=True)
    out_name = model_path.stem + "_unlocked.onnx"
    out_path = output_dir / out_name
    onnx.save(new_model, str(out_path))

    # Best-effort verification: load in onnxruntime and run on zero inputs.
    try:
        import onnxruntime as ort
        sess = ort.InferenceSession(str(out_path))
        feeds = {}
        for inp in sess.get_inputs():
            # Dynamic dims report as None/str; substitute 1.
            shape = [d if isinstance(d, int) and d > 0 else 1 for d in inp.shape]
            # Match the declared element type. Feeding float32 everywhere
            # (as before) makes verification fail on integer inputs even
            # when the unlocked model is fine.
            if "int64" in inp.type:
                dtype = np.int64
            elif "int32" in inp.type:
                dtype = np.int32
            else:
                dtype = np.float32
            feeds[inp.name] = np.zeros(shape, dtype=dtype)
        result = sess.run(None, feeds)
        print(f" ✓ Inference OK — output shapes: {[r.shape for r in result]}")
    except Exception as e:
        print(f" ✗ Inference failed: {e}")
    return out_path
def main():
    """Unlock the models named on the command line, or all of 11-33."""
    models_dir = Path("oneocr_extracted/onnx_models")
    output_dir = Path("oneocr_extracted/onnx_models_unlocked")

    # Explicit indices from argv, falling back to the full 11-33 range.
    indices = [int(arg) for arg in sys.argv[1:]] or list(range(11, 34))

    print(f"Unlocking {len(indices)} models...")
    print(f"Source: {models_dir}")
    print(f"Output: {output_dir}")
    print()

    results = {"success": [], "skip": [], "fail": []}
    for idx in indices:
        matches = list(models_dir.glob(f"model_{idx:02d}_*"))
        if not matches:
            print(f" model_{idx:02d}: NOT FOUND")
            results["fail"].append(idx)
            continue
        model_path = matches[0]
        try:
            out = unlock_model(model_path, output_dir)
        except Exception as e:
            print(f" model_{idx:02d}: ERROR — {e}")
            results["fail"].append(idx)
        else:
            # unlock_model returns the input path unchanged when it skips.
            bucket = "skip" if out == model_path else "success"
            results[bucket].append(idx)

    # Summary.
    print(f"\n{'='*60}")
    print("Results:")
    print(f" Unlocked: {len(results['success'])}{results['success']}")
    print(f" Skipped: {len(results['skip'])}{results['skip']}")
    print(f" Failed: {len(results['fail'])}{results['fail']}")


if __name__ == "__main__":
    main()